diff --git "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" index f6e520c0da..bec829723e 100644 --- "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -65,7 +65,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | torch | >=2.0 | 2.6.0/2.7.1 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | >=0.12 | 0.13 | | +| megatron_core | | 0.13 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.56.2 | | | modelscope | >=1.23 | | | diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 03b1558f3d..59ac888b65 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -65,7 +65,7 @@ Recommended Operating Environment: | torch | >=2.0 | 2.6.0/2.7.1 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | >=0.12 | 0.13 | | +| megatron_core | | 0.13 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.56.2 | | | modelscope | >=1.23 | | | diff --git a/swift/cli/_megatron/export.py b/swift/cli/_megatron/export.py new file mode 100644 index 0000000000..3eca73ca8c --- /dev/null +++ b/swift/cli/_megatron/export.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.megatron import megatron_export_main + +if __name__ == '__main__': + megatron_export_main() diff --git a/swift/cli/_megatron/main.py b/swift/cli/_megatron/main.py index a83dcac8b1..0fc2de0c92 100644 --- a/swift/cli/_megatron/main.py +++ b/swift/cli/_megatron/main.py @@ -10,11 +10,12 @@ 'pt': 'swift.cli._megatron.pt', 'sft': 'swift.cli._megatron.sft', 'rlhf': 'swift.cli._megatron.rlhf', + 'export': 'swift.cli._megatron.export', } def cli_main(): - return swift_cli_main(ROUTE_MAPPING) + return swift_cli_main(ROUTE_MAPPING, is_megatron=True) if __name__ == '__main__': diff --git a/swift/cli/main.py b/swift/cli/main.py index 7637a6d4ff..0e4367e59c 100644 --- a/swift/cli/main.py +++ b/swift/cli/main.py @@ -91,7 +91,7 @@ def _compat_web_ui(argv): logger.warning('Please use `swift app`.') -def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None: +def cli_main(route_mapping: Optional[Dict[str, str]] = None, is_megatron: bool = False) -> None: route_mapping = route_mapping or ROUTE_MAPPING argv = sys.argv[1:] _compat_web_ui(argv) @@ -101,7 +101,7 @@ def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None: torchrun_args = get_torchrun_args() prepare_config_args(argv) python_cmd = sys.executable - if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}: + if not is_megatron and (torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}): args = [python_cmd, file_path, *argv] else: args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv] diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 76fc748776..31c647224f 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -220,7 +220,10 @@ def from_pretrained(cls, checkpoint_dir: str): def _init_ckpt_dir(self, adapters=None): # compat megatron model = self.model or getattr(self, 'mcore_model', None) or getattr(self, 'load', None) - adapters = 
adapters or self.adapters or getattr(self, 'mcore_adapters', None) + adapters = adapters or self.adapters or getattr(self, 'mcore_adapters', None) or getattr( + self, 'adapter_load', None) + if isinstance(adapters, str): + adapters = [adapters] self.ckpt_dir = get_ckpt_dir(model, adapters) if self.ckpt_dir and self.load_args: self.load_args_from_ckpt() @@ -308,12 +311,13 @@ def get_model_processor(self, **kwargs): if self.tuner_backend == 'unsloth': return load_by_unsloth(self) - kwargs.update(self.get_model_kwargs()) + res = self.get_model_kwargs() + res.update(kwargs) # compat rlhf - kwargs['model_id_or_path'] = model or self.model - kwargs['model_type'] = model_type or self.model_type - kwargs['model_revision'] = model_revision or self.model_revision - kwargs['task_type'] = task_type or self.task_type - kwargs['num_labels'] = num_labels or self.num_labels + res['model_id_or_path'] = model or self.model + res['model_type'] = model_type or self.model_type + res['model_revision'] = model_revision or self.model_revision + res['task_type'] = task_type or self.task_type + res['num_labels'] = num_labels or self.num_labels - return get_model_tokenizer(**kwargs) + return get_model_tokenizer(**res) diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py index 49f35fc646..49bc8b53a2 100644 --- a/swift/llm/infer/utils.py +++ b/swift/llm/infer/utils.py @@ -142,11 +142,12 @@ def prepare_adapter(args, model, adapters=None): def prepare_model_template(args, **kwargs): + adapters = kwargs.get('adapters') model, processor = args.get_model_processor(**kwargs) template = args.get_template(processor) if model is not None: if template.use_model: template.model = model - model = prepare_adapter(args, model) + model = prepare_adapter(args, model, adapters=adapters) update_generation_config_eos_token(model.generation_config, template) return model, template diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py index 1ab31aa8f7..58e2065f1e 100644 --- a/swift/llm/model/patcher.py +++ b/swift/llm/model/patcher.py @@ -467,7 +467,7 @@ def patch_tp_plan(load_model: bool): transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ: yield return - logger.info('Patch tp_plan.') + logger.info_once('Patch tp_plan.') WORLD_SIZE = os.environ.get('WORLD_SIZE') os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE os.environ.pop('WORLD_SIZE') diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 5598200ac0..fce20eb7d2 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1687,6 +1687,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res["attention_mask"].dtype}' for i, seq_len in enumerate(seq_lens): res['attention_mask'][i, :, seq_len:] = 0 + res['attention_mask'] = ~res['attention_mask'] for key, pad_value in zip(keys, pad_values): if key not in res: diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 0a5a41ebc1..9815a0ea87 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -13,16 +13,20 @@ if TYPE_CHECKING: from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main - from .utils import convert_hf2mcore, convert_mcore2hf, prepare_mcore_model, adapter_state_dict_context - from .argument import MegatronTrainArguments, MegatronRLHFArguments + from .export import megatron_export_main + from .convert import convert_hf2mcore, convert_mcore2hf + from .utils import 
prepare_mcore_model, adapter_state_dict_context + from .argument import MegatronTrainArguments, MegatronRLHFArguments, MegatronExportArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer from .tuners import LoraParallelLinear else: _import_structure = { 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], - 'utils': ['convert_hf2mcore', 'convert_mcore2hf', 'prepare_mcore_model', 'adapter_state_dict_context'], - 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'], + 'export': ['megatron_export_main'], + 'convert': ['convert_hf2mcore', 'convert_mcore2hf'], + 'utils': ['prepare_mcore_model', 'adapter_state_dict_context'], + 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments', 'MegatronExportArguments'], 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], 'tuners': ['LoraParallelLinear'], diff --git a/swift/megatron/argument/__init__.py b/swift/megatron/argument/__init__.py index a2ad08daa3..e6d577e1ad 100644 --- a/swift/megatron/argument/__init__.py +++ b/swift/megatron/argument/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .export_args import MegatronExportArguments from .megatron_args import MegatronArguments from .rlhf_args import MegatronRLHFArguments from .train_args import MegatronTrainArguments diff --git a/swift/megatron/argument/export_args.py b/swift/megatron/argument/export_args.py new file mode 100644 index 0000000000..e3492e6371 --- /dev/null +++ b/swift/megatron/argument/export_args.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os
+from dataclasses import dataclass
+
+from swift.llm import HfConfigFactory
+from swift.llm.argument.base_args import to_abspath
+from swift.utils import get_logger
+from .megatron_args import MegatronArguments
+from .megatron_base_args import MegatronBaseArguments
+
+logger = get_logger()
+
+
+@dataclass
+class MegatronExportArguments(MegatronBaseArguments):
+    to_hf: bool = False
+    to_mcore: bool = False
+    test_convert_precision: bool = False
+    test_convert_dtype: str = 'float32'
+    exist_ok: bool = False
+
+    def _init_save(self):
+        if self.save is None:
+            ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
+            ckpt_dir, ckpt_name = os.path.split(ckpt_dir)
+            if self.to_mcore:
+                suffix = 'mcore'
+            elif self.to_hf:
+                suffix = 'hf'
+            self.save = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}')
+
+        self.save = to_abspath(self.save)
+        if not self.exist_ok and os.path.exists(self.save):
+            raise FileExistsError(f'args.save: `{self.save}` already exists.')
+        logger.info(f'args.save: `{self.save}`')
+
+    def __post_init__(self):
+        super().__post_init__()
+        self._init_save()
+        self.test_convert_dtype = HfConfigFactory.to_torch_dtype(self.test_convert_dtype)
+        extra_config = MegatronArguments.load_args_config(self.adapter_load or self.load)
+        extra_config['adapter_load'] = self.adapter_load
+        for k, v in extra_config.items():
+            setattr(self, k, v)
+        if self.to_hf or self.to_mcore:
+            self._init_convert()
+        if self.model_info.is_moe_model is not None and self.tensor_model_parallel_size > 1:
+            self.sequence_parallel = True
+            logger.info('Setting args.sequence_parallel: True')
+
+    def _init_convert(self):
+        convert_kwargs = {
+            'no_save_optim': True,
+            'no_save_rng': True,
+            'no_load_optim': True,
+            'no_load_rng': True,
+            'finetune': True,
+            'attention_backend': 'unfused',
+            'padding_free': False,
+        }
+        for k, v in convert_kwargs.items():
+            setattr(self, k, v)
+        if self.model_info.is_moe_model:
+            self.moe_grouped_gemm = True
diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py
index 4cd2fc7a18..203167ec32 100644
--- a/swift/megatron/argument/megatron_args.py
+++ b/swift/megatron/argument/megatron_args.py
@@ -9,7 +9,6 @@
 import torch
 from transformers.utils.versions import require_version
 
-from swift.llm.argument.base_args import to_abspath
 from swift.utils import get_dist_setting, get_logger, json_parse_to_dict
 
 logger = get_logger()
@@ -94,6 +93,11 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     torch_dtype: Optional[torch.dtype] = None
     padding_free: bool = True
     mlp_padding_free: bool = False
+    load_hf_checkpoint: bool = False
+    save_hf_checkpoint: bool = False
+    model: Optional[str] = None
+    adapters: List[str] = field(default_factory=list)
+    merge_lora: Optional[bool] = None
     # streaming dataloader
     dataloader_persistent_workers: bool = True
     dataloader_prefetch_factor: int = 10
@@ -123,6 +127,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     layer_types: Optional[List[str]] = None
     # qwen3_vl, qwen3_omni
     mrope_interleaved: Optional[bool] = None
+    # hf saver
+    max_shard_size: str = '5GB'
 
     @staticmethod
     def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
@@ -452,12 +458,17 @@ def __post_init__(self):
             self.seq_length = self.max_position_embeddings
         if self.position_embedding_type is None:
             self.position_embedding_type = 'rope'
-        if self.tensorboard_dir is None and self.save is not None:
-            self.tensorboard_dir = f'{self.save}/runs'
+        if self.merge_lora is None:
+            self.merge_lora = self.save_hf_checkpoint
+        if self.adapters or self.adapter_load or self.ref_adapter_load:
+            if self.train_type == 'full':
+                self.train_type = 'lora'
+                logger.info('Setting args.train_type: lora')
+            if self.adapters:
+                self._load_adapter_config()
 
         self._init_moe()
         self._init_mixed_precision()
 
-        self.tensorboard_dir = to_abspath(self.tensorboard_dir)
         self.megatron_extra_kwargs = json_parse_to_dict(self.megatron_extra_kwargs)
         self._init_no_rope_fusion()
@@ -498,3 +509,22 @@ def parse_to_megatron(self):
         # parameter conflict
         extra_args.pop('loss_scale', None)
         return extra_args
+
+    def _load_adapter_config(self):
+        assert len(self.adapters) == 1, 'Currently only support one adapter'
+        adapter_path = self.adapters[0]
+        adapter_config_path = os.path.join(adapter_path, 'adapter_config.json')
+        adapter_config = {}
+        if os.path.exists(adapter_config_path):
+            with open(adapter_config_path, 'r') as f:
+                adapter_config = json.load(f)
+        mapping = {'r': 'lora_rank', 'bias': 'lora_bias'}
+        for k in ['lora_alpha', 'lora_dropout', 'use_rslora']:
+            mapping[k] = k
+        for k, v in adapter_config.items():
+            if k not in mapping:
+                continue
+            k = mapping[k]
+            if v != getattr(self, k):
+                setattr(self, k, v)
+                logger.info(f'Setting {k}: {v}')
diff --git a/swift/megatron/argument/megatron_base_args.py b/swift/megatron/argument/megatron_base_args.py
new file mode 100644
index 0000000000..0f496aac42
--- /dev/null
+++ b/swift/megatron/argument/megatron_base_args.py
@@ -0,0 +1,48 @@
+import math
+import os
+from dataclasses import dataclass
+
+import json
+
+from swift.llm import BaseArguments
+from swift.utils import get_logger
+from ..model import get_megatron_model_meta
+from ..utils import convert_hf_config
+from .megatron_args import MegatronArguments
+
+logger = get_logger()
+
+
+@dataclass
+class MegatronBaseArguments(MegatronArguments, BaseArguments):
+
+    def __post_init__(self):
+        self.sequence_parallel_size = self.context_parallel_size
+        if self.packing:
+            self.padding_free = True
+        BaseArguments.__post_init__(self)
+        self.megatron_model_meta = get_megatron_model_meta(self.model_type)
+        self.seq_length = self.seq_length or self.packing_length or self.max_length
+        if self.streaming:
+            self.dataloader_type = 'external'
+            if self.num_workers > 1:
+                self.num_workers = 1
+                logger.info('Using streaming dataset, setting args.num_workers to 1.')
+
+    def init_model_args(self, tokenizer, config):
+        if self.task_type == 'seq_cls':
+            self.problem_type = self.problem_type or getattr(config, 'problem_type', None)
+            logger.info(f'args.problem_type: {self.problem_type}')
+        kwargs = convert_hf_config(config)
+        if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer):
+            kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128
+            self.initialize_embedding = True
+        logger.info(f'megatron_config: {kwargs}')
+        for k, v in kwargs.items():
+            if getattr(self, k) is None:
+                setattr(self, k, v)
+        MegatronArguments.__post_init__(self)
+        self.extra_args = self.parse_to_megatron()
+        self.extra_args['model_info'] = self.model_info
+        self.extra_args['model_meta'] = self.model_meta
+        self.extra_args['megatron_model_meta'] = self.megatron_model_meta
diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py
index 61535def15..b0c85bb20f 100644
--- a/swift/megatron/argument/train_args.py
+++ b/swift/megatron/argument/train_args.py
@@ -7,36 +7,17 @@
 from swift.llm import BaseArguments
 from swift.llm.argument.base_args import to_abspath
-from swift.utils import add_version_to_work_dir, get_logger,
init_process_group, is_master -from ..model import get_megatron_model_meta -from .megatron_args import MegatronArguments +from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_last_rank +from .megatron_base_args import MegatronBaseArguments logger = get_logger() @dataclass -class MegatronTrainArguments(MegatronArguments, BaseArguments): +class MegatronTrainArguments(MegatronBaseArguments): add_version: bool = True load_args: bool = False - def init_model_args(self, tokenizer, config): - if self.task_type == 'seq_cls': - self.problem_type = self.problem_type or getattr(config, 'problem_type', None) - logger.info(f'args.problem_type: {self.problem_type}') - kwargs = self.megatron_model_meta.convert_hf_config(config) - if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): - kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 - self.initialize_embedding = True - logger.info(f'megatron_config: {kwargs}') - for k, v in kwargs.items(): - if getattr(self, k) is None: - setattr(self, k, v) - MegatronArguments.__post_init__(self) - self.extra_args = self.parse_to_megatron() - self.extra_args['model_info'] = self.model_info - self.extra_args['model_meta'] = self.model_meta - self.extra_args['megatron_model_meta'] = self.megatron_model_meta - def _init_save(self): init_process_group(backend=self.ddp_backend, timeout=self.ddp_timeout) if self.save is None: @@ -45,7 +26,7 @@ def _init_save(self): if self.add_version: self.save = add_version_to_work_dir(self.save) logger.info(f'args.save: {self.save}') - if is_master(): + if is_last_rank(): os.makedirs(self.save, exist_ok=True) def _init_ckpt_dir(self, adapters=None): @@ -59,24 +40,22 @@ def _init_ckpt_dir(self, adapters=None): self.model = old_args.get('model') def __post_init__(self): - self.sequence_parallel_size = self.context_parallel_size - if self.packing: - self.padding_free = True self.load = to_abspath(self.load, check_path_exist=True) - BaseArguments.__post_init__(self) - self.megatron_model_meta = get_megatron_model_meta(self.model_type) + super().__post_init__() if len(self.dataset) == 0 and len(self.cached_dataset) == 0: raise ValueError(f'self.dataset: {self.dataset}, self.cached_dataset: {self.cached_dataset}. 
' 'Please input the training dataset.') self._init_save() - self.seq_length = self.seq_length or self.packing_length or self.max_length - if self.streaming: - self.dataloader_type = 'external' - if self.num_workers > 1: - self.num_workers = 1 - logger.info('Using streaming dataset, setting args.num_workers to 1.') - if self.load is None and self.no_initialization: + if self.tensorboard_dir is None and self.save is not None: + self.tensorboard_dir = f'{self.save}/runs' + self.tensorboard_dir = to_abspath(self.tensorboard_dir) + if self.load is None and self.no_initialization and not self.load_hf_checkpoint: raise ValueError('You did not pass `--load`, so you need to set `--no_initialization false` ' 'to allow the model to initialize weights properly.') if self.cached_dataset and self.context_parallel_size > 1: raise ValueError('`cached_dataset` does not support context parallelism.') + + def get_model_kwargs(self): + res = super().get_model_kwargs() + res['download_model'] = self.load_hf_checkpoint + return res diff --git a/swift/megatron/utils/convert.py b/swift/megatron/convert.py similarity index 78% rename from swift/megatron/utils/convert.py rename to swift/megatron/convert.py index aa3202f580..94d3d45a73 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/convert.py @@ -6,19 +6,20 @@ from typing import Any, Dict import torch +import torch.distributed as dist import torch.nn as nn +from megatron.core import mpu from megatron.training import get_args from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from megatron.training.initialize import initialize_megatron -from megatron.training.utils import get_ltor_masks_and_position_ids from swift.llm import (ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device, to_float_dtype) from swift.utils import get_logger, get_n_params_grads -from ..argument import MegatronArguments -from ..model import get_megatron_model_meta -from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard +from .argument import MegatronArguments +from .model import get_megatron_model_meta +from .utils import convert_hf_config, forward_step_helper, get_padding_to, patch_torch_dist_shard logger = get_logger() @@ -29,7 +30,7 @@ def _test_params_sum(model): n_parameter = 0 for n, p in model.named_parameters(): n_parameter += 1 - sum_ = p.cuda().float().abs().sum().cpu().item() + sum_ = p.to(device='cuda', dtype=torch.float32).abs().sum().cpu().item() if sum_ == 0: zero_count += 1 logger.warning(f'n: {n}, sum: {sum_}') @@ -37,9 +38,10 @@ def _test_params_sum(model): logger.warning(f'n: {n}, sum: {sum_}') else: total_sum += sum_ - logger.info(f'n_parameter: {n_parameter}') - logger.info(f'total_sum: {total_sum}') - logger.info(f'zero_count: {zero_count}') + cond = mpu.get_data_parallel_rank() == 0 + logger.info_if(f'n_parameter: {n_parameter}', cond=cond) + logger.info_if(f'total_sum: {total_sum}', cond=cond) + logger.info_if(f'zero_count: {zero_count}', cond=cond) def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): @@ -64,19 +66,23 @@ def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): @contextmanager -def _model_cpu_forward_context(modules, torch_dtype=None, device=None, share_embedding: bool = False): - origin_torch_dtype = next(modules[0].parameters()).dtype +def _model_cpu_forward_context(modules, + torch_dtype=None, + compute_device=None, + share_embedding: bool = False, + 
target_device='cpu'): + origin_torch_dtype = next(modules[-1].parameters()).dtype def _to_cuda_hook(module, args): - if device is not None or torch_dtype is not None: - module.to(device=device, dtype=torch_dtype) + if compute_device is not None or torch_dtype is not None: + module.to(device=compute_device, dtype=torch_dtype) args = to_float_dtype(args, dtype=torch_dtype) return args def _to_cpu_hook(module, args, output): if share_embedding and module is modules[0]: return - module.to(device='cpu', dtype=origin_torch_dtype) + module.to(device=target_device, dtype=origin_torch_dtype) hooks = [] for module in modules: @@ -144,32 +150,34 @@ def get_examples(is_multimodal: bool) -> Dict[str, Any]: def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float32): - hf_model.eval() - mg_model.eval() - _test_params_sum(hf_model) + template.set_mode('train') _test_params_sum(mg_model) - template.set_mode('train') - template.register_post_encode_hook([hf_model]) is_multimodal = template.model_meta.is_multimodal inputs = get_examples(is_multimodal) inputs = template.encode(inputs) - inputs = to_device(template.data_collator([inputs]), 'cuda') - - HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) + hf_inputs = to_device(template.data_collator([inputs]), 'cuda') mg_language_model = mg_model.language_model if is_multimodal else mg_model share_embedding = mg_language_model.share_embeddings_and_output_weights - model_arch = hf_model.model_meta.model_arch - ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] + if hf_model is not None: + hf_model.eval() + if dist.get_world_size() == 1: + _test_params_sum(hf_model) + template.register_post_encode_hook([hf_model]) + HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) + model_arch = hf_model.model_meta.model_arch + ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] + hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) + with torch.inference_mode(), _model_cpu_forward_context( + hf_modules, torch_dtype, share_embedding=share_embedding): + hf_inputs.pop('text_position_ids', None) + hf_logits = hf_model(**hf_inputs).logits + hf_logits = hf_logits.to('cuda') + hf_model.to('cpu') - hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) - with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding): - inputs.pop('text_position_ids', None) - hf_logits = hf_model(**inputs).logits - hf_model.to('cpu') - - input_ids = inputs['input_ids'] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True) + args = get_args() + template.use_megatron = True + mg_inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to(args)), 'cuda') packed_seq_params = None mg_torch_dtype = torch_dtype # thd @@ -179,24 +187,31 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float # attention_mask = None mg_language_model.config.fp8 = None # compat fp8 mg_modules = _find_modules(mg_language_model, ignore_modules=['visual']) - kwargs = {k: v for k, v in inputs.items() if k not in ['input_ids', 'attention_mask', 'labels']} - if 'position_ids' not in kwargs: - kwargs['position_ids'] = position_ids + for key in ['labels', 'num_samples', 'attention_mask_2d', 'text_position_ids']: + mg_inputs.pop(key, None) + mg_inputs.update({'packed_seq_params': packed_seq_params}) + mg_device = 
next(mg_language_model.parameters()).device with torch.inference_mode(), _model_cpu_forward_context( - mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding): - mg_logits = mg_model( - input_ids=input_ids, attention_mask=attention_mask, packed_seq_params=packed_seq_params, **kwargs) - args = get_args() + mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device): + # TODO: test pp tie_weights + mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=mg_torch_dtype) + if args.tensor_model_parallel_size > 1: + from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region + if mg_logits is not None: + mg_logits = gather_from_tensor_model_parallel_region(mg_logits) + if hf_model is None: + return if args.task_type == 'seq_cls': mg_logits = mg_logits[:, -1] mean_diff = (mg_logits - hf_logits).abs().mean().item() max_diff = (mg_logits - hf_logits).abs().max().item() print(f'mean_diff: {mean_diff}, max_diff: {max_diff}') else: + mg_logits = mg_logits[:, :hf_logits.shape[1]] token_mean_diff = (mg_logits - hf_logits).abs().mean(dim=-1) mean_diff = token_mean_diff.mean().item() max_diff = (mg_logits - hf_logits).abs().max().item() - loss_mask = (torch.roll(inputs['labels'], -1) != -100) + loss_mask = (torch.roll(hf_inputs['labels'], -1) != -100) mean_diff_with_loss = token_mean_diff[loss_mask].mean().item() max_diff_with_loss = (mg_logits - hf_logits)[loss_mask].abs().max().item() print(f'token_mean_diff: {token_mean_diff}') @@ -238,7 +253,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' - kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -246,7 +261,6 @@ def convert_hf2mcore(args: ExportArguments) -> None: current_convert_kwargs['moe_grouped_gemm'] = True megatron_args = MegatronArguments( **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta @@ -256,11 +270,11 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') - megatron_model_meta.convert_hf2mcore(hf_model, mg_model) + bridge = megatron_model_meta.bridge_cls() + bridge.load_weights(mg_model, args.model_info.model_dir) + logger.info('Successfully transferred HF model weights to MG model.') if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) - del hf_model - logger.info('Successfully transferred HF model weights to MG model.') args.save_args() logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) @@ -274,7 +288,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
- kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -291,7 +305,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: **current_convert_kwargs, save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta @@ -311,23 +324,13 @@ def convert_mcore2hf(args: ExportArguments) -> None: mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') if args.to_hf: - hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) - megatron_model_meta.convert_mcore2hf(hf_model, mg_model) + bridge = megatron_model_meta.bridge_cls() + logger.info('Converting weights and saving the model...') + bridge.save_weights([mg_model], args.output_dir) + logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') if args.test_convert_precision: + hf_model, template = prepare_model_template(args, model=args.output_dir) test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) - del mg_model - logger.info('Successfully transferred MG model weights to HF model.') - ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load - logger.info('Saving the model...') - save_checkpoint( - hf_model, - processor, - args.output_dir, - safe_serialization=args.safe_serialization, - model_dirs=[ckpt_dir, args.model_dir], - max_shard_size=args.max_shard_size, - additional_saved_files=hf_model.model_meta.additional_saved_files) - logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') elif args.to_mcore: if args.thread_count is None: checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9 diff --git a/swift/megatron/export/__init__.py b/swift/megatron/export/__init__.py new file mode 100644 index 0000000000..97590932f0 --- /dev/null +++ b/swift/megatron/export/__init__.py @@ -0,0 +1 @@ +from .export import megatron_export_main diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py new file mode 100644 index 0000000000..6ff56a52b0 --- /dev/null +++ b/swift/megatron/export/export.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os
+import shutil
+from typing import List, Optional, Union
+
+import torch.distributed as dist
+from megatron.core import mpu
+from megatron.training import initialize_megatron
+from megatron.training.checkpointing import load_checkpoint
+from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint
+
+from swift.llm import SwiftPipeline, prepare_model_template
+from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank
+from ..argument import MegatronExportArguments
+from ..convert import test_convert_precision
+from ..utils import adapter_state_dict_context, patch_load_base_checkpoint, prepare_mcore_model
+
+logger = get_logger()
+
+
+class MegatronExport(SwiftPipeline):
+    args_class = MegatronExportArguments
+    args: args_class
+
+    def run(self):
+        args = self.args
+        if args.to_hf:
+            self.convert_mcore2hf()
+        elif args.to_mcore:
+            self.convert_hf2mcore()
+
+    def convert_mcore2hf(self) -> None:
+        args = self.args
+        _, template = prepare_model_template(args, load_model=False)
+        self.processor = template.processor
+        args.init_model_args(self.tokenizer, self.processor.model_info.config)
+        megatron_model_meta = args.megatron_model_meta
+        extra_args_provider = megatron_model_meta.extra_args_provider
+        initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args)
+
+        pre_process = mpu.is_pipeline_first_stage()
+        post_process = mpu.is_pipeline_last_stage()
+        mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process)
+        with patch_load_base_checkpoint():
+            load_checkpoint([mg_model], None, None, strict=True)
+        if args.adapter_load is not None:
+            prepare_mcore_model(mg_model)
+            with adapter_state_dict_context():
+                load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False)
+        logger.info('Converting weights and saving the model...')
+        bridge = megatron_model_meta.bridge_cls()
+        save_peft_format = args.train_type == 'lora' and not args.merge_lora
+        bridge.save_weights([mg_model], args.save, is_peft_format=save_peft_format)
+        if is_last_rank():
+            args_path = os.path.join(os.path.dirname(args.save), 'args.json')
+            if os.path.exists(args_path):
+                shutil.copy(args_path, os.path.join(args.save, 'args.json'))
+        if args.test_convert_precision:
+            with disable_safe_ddp_context_use_barrier():
+                if save_peft_format:
+                    kwargs = {'adapters': [args.save]}
+                else:
+                    kwargs = {'model': args.save}
+                hf_model = prepare_model_template(args, device_map='cpu', **kwargs)[0] if is_last_rank() else None
+                test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+            dist.barrier()
+
+    def convert_hf2mcore(self) -> None:
+        args = self.args
+        _, template = prepare_model_template(args, load_model=False)
+        self.processor = template.processor
+        args.init_model_args(self.tokenizer, self.processor.model_info.config)
+        megatron_model_meta = args.megatron_model_meta
+        extra_args_provider = megatron_model_meta.extra_args_provider
+        initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args)
+
+        pre_process = mpu.is_pipeline_first_stage()
+        post_process = mpu.is_pipeline_last_stage()
+        mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process)
+        logger.info('Megatron model created successfully.')
+        bridge = megatron_model_meta.bridge_cls()
+        bridge.load_weights(mg_model, args.model_info.model_dir)
+        dist.barrier()
+        if args.adapters:
+            prepare_mcore_model(mg_model)
+            assert len(args.adapters) == 1, 'Currently only support one
adapter' + bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) + logger.info('Successfully transferred HF model weights to MG model.') + if args.test_convert_precision: + with disable_safe_ddp_context_use_barrier(): + hf_model = prepare_model_template(args, device_map='cpu')[0] if is_last_rank() else None + test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + dist.barrier() + if is_last_rank(): + args.save_args(args.save) + logger.info('Saving the model...') + with adapter_state_dict_context(): + mg_save_checkpoint(1, [mg_model], None, None, 0) + logger.info_if(f'Successfully saved Megatron model weights in `{args.save}`.', cond=is_last_rank()) + + +def megatron_export_main(args: Optional[Union[List[str], MegatronExportArguments]] = None): + return MegatronExport(args).main() diff --git a/swift/megatron/init.py b/swift/megatron/init.py index e5dfcd7725..a1c6272d18 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -379,6 +379,15 @@ def sharded_state_dict( TEGroupedLinear.sharded_state_dict = sharded_state_dict +def _patch_megatron_tokenizer(): + from megatron.training import global_vars + + def build_tokenizer(args): + return 'dummy_tokenizer' + + global_vars.build_tokenizer = build_tokenizer + + def _patch_peft_ModulesToSaveWrapper(): if version.parse(peft.__version__) >= version.parse('0.16'): from peft.utils import other as peft_module @@ -509,6 +518,16 @@ def _worker(plan_shard): FileSystemReader.read_data = read_data +def _patch_validate_non_overlapping_shards_metadata(): + # too slow + from torch.distributed._shard.sharded_tensor import api + + def validate_non_overlapping_shards_metadata(*args, **kwargs): + pass + + api.validate_non_overlapping_shards_metadata = validate_non_overlapping_shards_metadata + + def _patch_TELinear(): from megatron.core.extensions.transformer_engine import TELinear @@ -664,6 +683,7 @@ def _patch_megatron(): _patch_compile_helpers() _patch_build_train_valid_test_datasets() _patch_mrope() + _patch_megatron_tokenizer() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: @@ -671,6 +691,11 @@ def _patch_megatron(): logger.info('Patch FileSystemReader successfully applied.') except Exception: pass + try: + _patch_validate_non_overlapping_shards_metadata() + except Exception: + logger.warning('Patch validate_non_overlapping_shards_metadata failed.') + pass try: _patch_peft_BaseTuner() _patch_peft_ModulesToSaveWrapper() diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 3c882c9864..35b8777147 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import gpt, mm_gpt +# from . import gpt, mm_gpt +from . import gpt from .constant import MegatronModelType from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py deleted file mode 100644 index 7a68537efc..0000000000 --- a/swift/megatron/model/config.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Any, Dict - -from swift.utils import get_logger - -logger = get_logger() -config_mapping = { - 'num_layers': ['num_hidden_layers'], - 'hidden_size': ['hidden_size'], - 'ffn_hidden_size': ['intermediate_size'], - 'num_attention_heads': ['num_attention_heads'], - 'num_query_groups': ['num_key_value_heads'], - 'max_position_embeddings': ['max_position_embeddings'], - 'norm_epsilon': ['rms_norm_eps'], - 'rotary_base': ['rope_theta'], - 'padded_vocab_size': ['vocab_size'], - 'attention_dropout': ['attention_dropout'], - 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], - 'swiglu': ['hidden_act'], - 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], - 'disable_bias_linear': ['mlp_bias'], - 'kv_channels': ['head_dim', 'v_head_dim'], - 'architectures': ['architectures'], - # moe - 'moe_ffn_hidden_size': ['moe_intermediate_size'], - 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'], - 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'], - 'moe_router_pre_softmax': ['norm_topk_prob'], - # deepseek - 'q_lora_rank': ['q_lora_rank'], - 'kv_lora_rank': ['kv_lora_rank'], - 'moe_router_score_function': ['scoring_func'], - 'qk_head_dim': ['qk_nope_head_dim'], - 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], - 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], - 'qk_layernorm': ['use_qk_norm'], - # qwen3_next - 'linear_num_value_heads': ['linear_num_value_heads'], - 'linear_num_key_heads': ['linear_num_key_heads'], - 'linear_key_head_dim': ['linear_key_head_dim'], - 'linear_value_head_dim': ['linear_value_head_dim'], - 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], - 'full_attention_interval': ['full_attention_interval'], - # other - 'original_max_position_embeddings': ['original_max_position_embeddings'], - 'partial_rotary_factor': ['partial_rotary_factor'], - 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], - 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], -} - - -def convert_hf_config(config, _internal_call=False) -> Dict[str, Any]: - megatron_config = {} - for k, hf_keys in config_mapping.items(): - for hf_k in hf_keys: - if hasattr(config, hf_k): - hf_v = getattr(config, hf_k) - if hf_v is None: - continue - if k == 'rotary_base': - megatron_config[k] = int(hf_v) - elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: - megatron_config[k] = not hf_v - elif k == 'swiglu': - if hf_v == 'silu': - megatron_config[k] = True - else: - if k == 'kv_lora_rank': - megatron_config['multi_latent_attention'] = True - elif k == 'architectures': - if _internal_call: - k = 'llm_architectures' - megatron_config[k] = hf_v - break - for key in ['text_config', 'llm_config', 'thinker_config']: - if hasattr(config, key): - megatron_config.update(convert_hf_config(getattr(config, key), _internal_call=True)) - # compat llama3 - if getattr(config, 'rope_scaling', None) is not None: - if isinstance(config.rope_scaling, int): - megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}, - elif isinstance(config.rope_scaling, dict): - megatron_config['rope_scaling'] = config.rope_scaling - return megatron_config diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index f3eb68e0cd..9dd370ccfd 100644 --- a/swift/megatron/model/gpt/__init__.py +++ 
b/swift/megatron/model/gpt/__init__.py @@ -4,9 +4,6 @@ from ..gpt_model import GPTModel from ..register import MegatronModelMeta, register_megatron_model from . import qwen3_next -from .config import convert_gpt_hf_config -from .hf2mcore import convert_hf2mcore -from .mcore2hf import convert_mcore2hf register_megatron_model( MegatronModelMeta( @@ -58,8 +55,4 @@ ModelType.deepseek_v3_1, ModelType.ernie_thinking, ], - model_cls=GPTModel, - convert_hf_config=convert_gpt_hf_config, - convert_mcore2hf=convert_mcore2hf, - convert_hf2mcore=convert_hf2mcore, )) diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py deleted file mode 100644 index cc8960163b..0000000000 --- a/swift/megatron/model/gpt/hf2mcore.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Optional - -import torch -from megatron.training import get_args -from torch import nn - - -def set_mla_attn_state(args, mg_attn, hf_attn): - mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) - if args.q_lora_rank is None: - mg_attn.linear_q_proj.weight.data.copy_(hf_attn.q_proj.weight) - else: - mg_attn.linear_q_down_proj.weight.data.copy_(hf_attn.q_a_proj.weight) - mg_attn.linear_q_up_proj.weight.data.copy_(hf_attn.q_b_proj.weight) - mg_attn.linear_kv_down_proj.weight.data.copy_(hf_attn.kv_a_proj_with_mqa.weight) - mg_attn.linear_kv_up_proj.weight.data.copy_(hf_attn.kv_b_proj.weight) - if args.qk_layernorm: - mg_attn.linear_kv_up_proj.layer_norm_weight.data.copy_(hf_attn.kv_a_layernorm.weight) - - -def set_attn_state(args, mg_attn, hf_attn): - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - - # Copy weights - mg_attn.linear_qkv.weight.data.copy_( - torch.cat([ - hf_attn.q_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - hf_attn.k_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - hf_attn.v_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - ], - dim=1).reshape((-1, args.hidden_size))) - mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) - - # Copy bias - if args.add_qkv_bias: - mg_attn.linear_qkv.bias.data.copy_( - torch.cat([ - hf_attn.q_proj.bias.reshape((num_query_groups, -1)), - hf_attn.k_proj.bias.reshape((num_query_groups, -1)), - hf_attn.v_proj.bias.reshape((num_query_groups, -1)), - ], - dim=1).reshape(-1)) - if args.qk_layernorm: - q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm - k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm - mg_attn.q_layernorm.weight.data.copy_(q_norm.weight) - mg_attn.k_layernorm.weight.data.copy_(k_norm.weight) - - -def _set_mlp_state(mg_mlp, hf_mlp, group_idx: Optional[int] = None): - hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) - if group_idx is None: - linear_fc1_weight = mg_mlp.linear_fc1.weight - linear_fc2_weight = mg_mlp.linear_fc2.weight - else: - linear_fc1_weight = getattr(mg_mlp.linear_fc1, f'weight{group_idx}') - linear_fc2_weight = getattr(mg_mlp.linear_fc2, f'weight{group_idx}') - if hf_grouped: - linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj[group_idx].t()) - linear_fc2_weight.data.copy_(hf_mlp.down_proj[group_idx].t()) - else: - if hasattr(hf_mlp, 'gate_up_proj'): - linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj.weight) - else: - linear_fc1_weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) - linear_fc2_weight.data.copy_(hf_mlp.down_proj.weight) - - -def 
_set_moe_state(args, mg_mlp, hf_mlp): - hf_gate = hf_mlp.gate - if hasattr(hf_gate, 'wg'): - hf_gate = hf_gate.wg - mg_mlp.router.weight.data.copy_(hf_gate.weight) - if args.moe_router_enable_expert_bias: - mg_mlp.router.expert_bias.data.copy_(hf_gate.e_score_correction_bias) - if mg_mlp.shared_experts is not None: - if hasattr(hf_mlp, 'shared_experts'): - hf_shared_expert = hf_mlp.shared_experts - elif hasattr(hf_mlp, 'shared_mlp'): - hf_shared_expert = hf_mlp.shared_mlp - else: - hf_shared_expert = hf_mlp.shared_expert - _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert) - if mg_mlp.shared_experts.gate_weight is not None: - mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) - for expert_idx in range(args.num_experts): - hf_expert = hf_mlp.experts - if hasattr(hf_expert, '__len__'): - hf_expert = hf_expert[expert_idx] - _set_mlp_state(mg_mlp.experts, hf_expert, group_idx=expert_idx) - - -def set_mlp_state(args, mg_mlp, hf_mlp): - if 'moe' in mg_mlp.__class__.__name__.lower(): - _set_moe_state(args, mg_mlp, hf_mlp) - else: - _set_mlp_state(mg_mlp, hf_mlp) - - -def set_layer_state(args, mg_model, hf_model, layer_idx): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.layers[layer_idx] - if args.multi_latent_attention: - set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight) - else: - set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - mg_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - - post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight - if 'moe' in mg_layer.mlp.__class__.__name__.lower(): - mg_layer.pre_mlp_layernorm.weight.data.copy_(post_attention_layernorm_weight) - else: - mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_attention_layernorm_weight) - - -def convert_hf2mcore(hf_model, mg_model): - args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) - if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model.model, layer_idx) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py deleted file mode 100644 index eac8023801..0000000000 --- a/swift/megatron/model/gpt/mcore2hf.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Optional - -from megatron.training import get_args -from torch import nn - - -def set_mla_attn_state(args, mg_attn, hf_attn): - hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) - if args.q_lora_rank is None: - hf_attn.q_proj.weight.data.copy_(mg_attn.linear_q_proj.weight) - else: - hf_attn.q_a_proj.weight.data.copy_(mg_attn.linear_q_down_proj.weight) - hf_attn.q_b_proj.weight.data.copy_(mg_attn.linear_q_up_proj.weight) - hf_attn.kv_a_proj_with_mqa.weight.data.copy_(mg_attn.linear_kv_down_proj.weight) - hf_attn.kv_b_proj.weight.data.copy_(mg_attn.linear_kv_up_proj.weight) - if args.qk_layernorm: - hf_attn.kv_a_layernorm.weight.data.copy_(mg_attn.linear_kv_up_proj.layer_norm_weight) - - -def set_attn_state(args, mg_attn, hf_attn): - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - # Copy weights - mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) - q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ - 0] // num_query_groups - hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size)) - hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size)) - hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size)) - hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) - - # Copy bias - if args.add_qkv_bias: - mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1)) - hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1)) - hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1)) - hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1)) - - if args.qk_layernorm: - q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm - k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm - q_norm.weight.data.copy_(mg_attn.q_layernorm.weight) - k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) - - -def _set_moe_state(args, mg_mlp, hf_mlp): - hf_gate = hf_mlp.gate - if hasattr(hf_gate, 'wg'): - hf_gate = hf_gate.wg - hf_gate.weight.data.copy_(mg_mlp.router.weight) - if args.moe_router_enable_expert_bias: - hf_gate.e_score_correction_bias.data.copy_(mg_mlp.router.expert_bias) - if mg_mlp.shared_experts is not None: - if hasattr(hf_mlp, 'shared_experts'): - hf_shared_expert = hf_mlp.shared_experts - elif hasattr(hf_mlp, 'shared_mlp'): - hf_shared_expert = hf_mlp.shared_mlp - else: - hf_shared_expert = hf_mlp.shared_expert - _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert) - if mg_mlp.shared_experts.gate_weight is not None: - hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) - for expert_idx in range(args.num_experts): - hf_expert = hf_mlp.experts - if hasattr(hf_expert, '__len__'): - hf_expert = hf_expert[expert_idx] - _set_mlp_state(mg_mlp.experts, hf_expert, group_idx=expert_idx) - - -def _set_mlp_state(mg_mlp, hf_mlp, group_idx: Optional[int] = None): - hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) - if group_idx is None: - linear_fc1_weight = mg_mlp.linear_fc1.weight - linear_fc2_weight = mg_mlp.linear_fc2.weight - else: - linear_fc1_weight = getattr(mg_mlp.linear_fc1, f'weight{group_idx}') - linear_fc2_weight = getattr(mg_mlp.linear_fc2, f'weight{group_idx}') - - if hf_grouped: - hf_mlp.gate_up_proj.data[group_idx] = linear_fc1_weight.t() - 
hf_mlp.down_proj.data[group_idx] = linear_fc2_weight.t() - else: - if hasattr(hf_mlp, 'gate_up_proj'): - hf_mlp.gate_up_proj.weight.data.copy_(linear_fc1_weight) - else: - ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] - hf_mlp.gate_proj.weight.data.copy_(linear_fc1_weight[:ffn_hidden_size]) - hf_mlp.up_proj.weight.data.copy_(linear_fc1_weight[ffn_hidden_size:]) - hf_mlp.down_proj.weight.data.copy_(linear_fc2_weight) - - -def set_mlp_state(args, mg_mlp, hf_mlp): - if 'moe' in mg_mlp.__class__.__name__.lower(): - _set_moe_state(args, mg_mlp, hf_mlp) - else: - _set_mlp_state(mg_mlp, hf_mlp) - - -def set_layer_state(args, mg_model, hf_model, layer_idx): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.layers[layer_idx] - - if args.multi_latent_attention: - set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.input_layernorm.weight) - else: - set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - - post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight - if 'moe' in mg_layer.mlp.__class__.__name__.lower(): - post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight) - else: - post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) - - -def convert_mcore2hf(hf_model, mg_model): - args = get_args() - hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) - if args.untie_embeddings_and_output_weights: - lm_head_weight = hf_model.score.weight if args.task_type == 'seq_cls' else hf_model.lm_head.weight - lm_head_weight.data.copy_(mg_model.output_layer.weight) - hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) - for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model.model, layer_idx) diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index 1d291b6a4c..2dd9026dd9 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -3,7 +3,6 @@ from typing import Optional, Tuple, Union import torch -from megatron.core import mpu from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb @@ -22,9 +21,8 @@ from swift.llm import ModelType from swift.utils import get_logger from ..constant import MegatronModelType -from ..gpt_model import GPTModel +from ..gpt_bridge import GPTBridge from ..register import MegatronModelMeta, register_megatron_model -from .config import convert_gpt_hf_config try: from flashattn_hopper.flash_attn_interface import _flash_attn_forward @@ -473,61 +471,26 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): return block_spec -def convert_mcore2hf_qwen3_next(hf_model, mg_model): - from .mcore2hf import set_mlp_state, set_attn_state - args = get_args() - hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) - if args.untie_embeddings_and_output_weights: - lm_head_weight = hf_model.score.weight if args.task_type == 'seq_cls' else hf_model.lm_head.weight - lm_head_weight.data.copy_(mg_model.output_layer.weight) - 
hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight - 1) - for layer_idx in range(args.num_layers): - layer_type = args.layer_types[layer_idx] - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - mg_attn = mg_layer.self_attention - - if layer_type == 'linear_attention': - hf_layer.linear_attn.load_state_dict(mg_attn.state_dict(), strict=False) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.input_layernorm.weight - 1) - elif layer_type == 'full_attention': - hf_attn = hf_layer.self_attn - set_attn_state(args, mg_attn, hf_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_attn.linear_qkv.layer_norm_weight - 1) - if args.qk_layernorm: - hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight - 1) - hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight - 1) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - hf_layer.post_attention_layernorm.weight.data.copy_(mg_layer.pre_mlp_layernorm.weight - 1) +class Qwen3NextBridge(GPTBridge): + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool, offset: float = 0): + if hf_key in { + 'model.norm.weight', 'q_norm.weight', 'k_norm.weight', 'input_layernorm.weight', + 'post_attention_layernorm.weight' + }: + offset = -1 if reverse else 1 + else: + assert 'norm' not in hf_key, f'hf_key: {hf_key}' # just check + return super()._set_state_dict(state_dict, res_state_dict, hf_key, mg_key, reverse, offset) -def convert_hf2mcore_qwen3_next(hf_model, mg_model): - from .hf2mcore import set_mlp_state, set_attn_state - args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) - if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight + 1) - for layer_idx in range(args.num_layers): - layer_type = args.layer_types[layer_idx] - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - mg_attn = mg_layer.self_attention - + def _set_layer_attn(self, state_dict, layer_idx: int, reverse: bool): + layer_type = self.args.layer_types[layer_idx] if layer_type == 'linear_attention': - mg_attn.load_state_dict(hf_layer.linear_attn.state_dict(), strict=False) - mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight + 1) + res = self._replace_prefix(state_dict, 'linear_attn.', 'self_attention.', reverse) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) elif layer_type == 'full_attention': - hf_attn = hf_layer.self_attn - set_attn_state(args, mg_attn, hf_attn) - mg_attn.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight + 1) - if args.qk_layernorm: - mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight + 1) - mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight + 1) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - mg_layer.pre_mlp_layernorm.weight.data.copy_(hf_layer.post_attention_layernorm.weight + 1) + res = super()._set_layer_attn(state_dict, layer_idx, reverse) + return res register_megatron_model( @@ -537,9 +500,6 @@ def convert_hf2mcore_qwen3_next(hf_model, mg_model): ModelType.qwen3_next, ModelType.qwen3_next_thinking, ], - model_cls=GPTModel, - convert_hf_config=convert_gpt_hf_config, get_transformer_layer_spec=get_qwen3_next_transformer_layer_spec, - convert_mcore2hf=convert_mcore2hf_qwen3_next, - 
convert_hf2mcore=convert_hf2mcore_qwen3_next, + bridge_cls=Qwen3NextBridge, )) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py new file mode 100644 index 0000000000..a599d6e639 --- /dev/null +++ b/swift/megatron/model/gpt_bridge.py @@ -0,0 +1,715 @@ +from copy import copy +from typing import Dict, Literal, Optional, Union + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.training import get_args +from tqdm import tqdm + +from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint +from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank +from ..tuners import LoraParallelLinear +from ..utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver + +logger = get_logger() + + +class GPTBridge: + lm_layers_prefix = 'model.layers' # HF model + + def __init__(self, disable_tqmd: bool = False): + self.args = get_args() + self.disable_tqmd = disable_tqmd + self._target_device = None + self._only_last_rank = False + self._peft_target_modules = set() + self._is_peft_format = False + model_info = self.args.model_info + with torch.device('meta'), disable_safe_ddp_context_use_barrier(): + self.hf_model, self.processor = get_model_tokenizer( + model_info.model_dir, model_type=model_info.model_type, return_dummy_model=True) + self.hf_layers = deep_getattr(self.hf_model, self.lm_layers_prefix) + self.tp_size = self.args.tensor_model_parallel_size + self.pp_size = self.args.pipeline_model_parallel_size + self.etp_size = self.args.expert_tensor_parallel_size + self.ep_size = self.args.expert_model_parallel_size + + self.tp_group = mpu.get_tensor_model_parallel_group() + self.pp_group = mpu.get_pipeline_model_parallel_group() + self.etp_group = mpu.get_expert_tensor_parallel_group() + self.ep_group = mpu.get_expert_model_parallel_group() + + self.tp_rank = mpu.get_tensor_model_parallel_rank() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + self.etp_rank = mpu.get_expert_tensor_parallel_rank() + self.ep_rank = mpu.get_expert_model_parallel_rank() + + @staticmethod + def _get_tp_split_dim(mg_key: str) -> Optional[int]: + if 'lora_A' not in mg_key and 'lora_B' not in mg_key: + key, suffix = mg_key.rsplit('.', 2)[-2:] + if suffix == 'layer_norm_weight': + return + if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: + return 0 + elif key in {'linear_proj', 'linear_fc1', 'linear_fc2'}: + # linear_fc1 shape [2, X, Y] + return 1 + else: + mg_key_splited = mg_key.rsplit('.', 3) + key, lora_name = mg_key_splited[:2] + if lora_name == 'lora_A': + if key in {'linear_proj', 'linear_fc2'}: + return 1 + elif lora_name == 'lora_B': + if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: + return 0 + elif key in {'linear_fc1'}: + # linear_fc1 shape [2, X, Y] + return 1 + + def _set_weight( + self, + mg_param: torch.Tensor, + hf_weight: torch.Tensor, + mg_key: str, + offset: float = 0, + is_expert: bool = False, + ): + # tp/etp + tp_dim = self._get_tp_split_dim(mg_key) + hf_weight = hf_weight.to(device=mg_param.device, dtype=mg_param.dtype) + tp_size = self.etp_size if is_expert else self.tp_size + tp_rank = self.etp_rank if is_expert else self.tp_rank + tp_group = self.etp_group if is_expert else self.tp_group + if tp_dim is not None and tp_size > 1: + if tp_rank == 0: + splited_weights = [t.contiguous() for t in hf_weight.chunk(tp_size, dim=tp_dim)] + else: + splited_weights = None + tensor = torch.empty_like(mg_param.data) + dist.scatter( + tensor, + 
splited_weights, + src=dist.get_global_rank(tp_group, 0), + group=tp_group, + ) + del splited_weights + else: + tensor = hf_weight + if offset: + tensor = tensor + offset + mg_param.data.copy_(tensor) + + def _get_weight(self, mg_weight: torch.Tensor, mg_key: str, offset: int = 0, is_expert: bool = False): + # tp/etp + tp_dim = self._get_tp_split_dim(mg_key) + tensor = mg_weight + tp_size = self.etp_size if is_expert else self.tp_size + tp_group = self.etp_group if is_expert else self.tp_group + if tensor is not None and tp_dim is not None and tp_size > 1: + if tp_dim == 0: + # save memory + tensor_shape = list(tensor.shape) + tensor_shape[0] *= tp_size + output = tensor.new_empty(tensor_shape) + dist.all_gather_into_tensor( + output, + tensor, + group=tp_group, + ) + tensor = output + else: + output = [torch.empty_like(tensor) for _ in range(tp_size)] + dist.all_gather( + output, + tensor, + group=tp_group, + ) + tensor = torch.cat(output, dim=tp_dim) + del output + # pp/ep + for parallel_state in ['ep', 'pp']: + if parallel_state == 'pp' and self.pp_size > 1: + parallel_group = self.pp_group + parallel_rank = self.pp_rank + elif parallel_state == 'ep' and is_expert and self.ep_size > 1: + parallel_group = self.ep_group + parallel_rank = self.ep_rank + else: + continue + src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda') + dist.all_reduce(src_rank, group=parallel_group) + src_rank = dist.get_global_rank(parallel_group, src_rank.item()) + meta_data = torch.zeros(10, dtype=torch.int64, device='cuda') + dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3} + dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} + if tensor is None: + dist.broadcast(meta_data, src=src_rank, group=parallel_group) + shape = meta_data[1:1 + meta_data[0]].tolist() + dtype = dtype_mapping_r[meta_data[-1].item()] + tensor = torch.empty(shape, device='cuda', dtype=dtype) + dist.broadcast(tensor, src=src_rank, group=parallel_group) + else: + meta_data[0] = tensor.ndim + meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda') + meta_data[-1] = dtype_mapping[tensor.dtype] + dist.broadcast(meta_data, src=src_rank, group=parallel_group) + dist.broadcast(tensor, src=src_rank, group=parallel_group) + assert tensor is not None, f'mg_key: {mg_key}' + if offset: + tensor = tensor + offset + if self._target_device is not None: + tensor = tensor.to(device=self._target_device) + if self._only_last_rank and not is_last_rank(): + tensor = None + return tensor + + def _set_state_dict(self, + mg_module, + mg_key: str, + hf_state_dict, + hf_key: str, + to_mcore: bool, + offset: float = 0, + is_expert: bool = False): + module_key, param_key = mg_key.rsplit('.', 1) + sub_module = deep_getattr(mg_module, module_key) + if isinstance(sub_module, LoraParallelLinear) and self._is_peft_format and param_key != 'layer_norm_weight': + if to_mcore: + hf_module_key, hf_param_key = hf_key.rsplit('.', 1) + lora_A_key = f'{module_key}.lora_A.default.{param_key}' + lora_B_key = f'{module_key}.lora_B.default.{param_key}' + mg_lora_A = deep_getattr(mg_module, f'{lora_A_key}') + mg_lora_B = deep_getattr(mg_module, f'{lora_B_key}') + hf_lora_A = hf_state_dict[f'{hf_module_key}.lora_A.{hf_param_key}'].load() + hf_lora_B = hf_state_dict[f'{hf_module_key}.lora_B.{hf_param_key}'].load() + self._set_weight(mg_lora_A, hf_lora_A, lora_A_key, offset, is_expert) + self._set_weight(mg_lora_B, hf_lora_B, lora_B_key, offset, is_expert) + 
else: + hf_module_key, hf_param_key = hf_key.rsplit('.', 1) + lora_A_key = f'{module_key}.lora_A.default.{param_key}' + lora_B_key = f'{module_key}.lora_B.default.{param_key}' + lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data') + lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data') + hf_lora_A_key = f'{hf_module_key}.lora_A.{hf_param_key}' + hf_lora_B_key = f'{hf_module_key}.lora_B.{hf_param_key}' + lora_A = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert) + lora_B = self._get_weight(lora_B_tensor, lora_B_key, offset, is_expert) + if lora_A is not None: + self._peft_target_modules.add(hf_module_key) + hf_state_dict[hf_lora_A_key] = lora_A + hf_state_dict[hf_lora_B_key] = lora_B + elif not self._is_peft_format: + if isinstance(sub_module, LoraParallelLinear): + mg_param = deep_getattr(sub_module, f'base_layer.{param_key}') + else: + mg_param = deep_getattr(sub_module, param_key) + if to_mcore: + assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' + hf_weight = hf_state_dict[hf_key].load() + self._set_weight(mg_param, hf_weight, mg_key, offset, is_expert) + else: + weight = self._get_weight(None if mg_param is None else mg_param.data, mg_key, offset, is_expert) + if weight is not None: + hf_state_dict[hf_key] = weight + + @staticmethod + def _remove_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + @staticmethod + def _add_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {f'{prefix}{k}': v for k, v in state_dict.items()} + + @staticmethod + def _filter_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k: v for k, v in state_dict.items() if k.startswith(prefix)} + + @staticmethod + def _is_moe(state_dict): + for k, v in state_dict.items(): + if 'experts.' 
in k: + return True + return False + + def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_attn = self.hf_layers[layer_idx].self_attn + args = self.args + num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) + if to_mcore: + if isinstance(mg_attn.linear_qkv, LoraParallelLinear): + lora_A = hf_state_dict['q_proj.lora_A.weight'].load() + assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and ( + lora_A == hf_state_dict['v_proj.lora_A.weight'].load() + ).all(), 'Need to ensure QKV\'s lora_A are consistent' + q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load() + lora_B = torch.cat([ + q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])), + hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), + hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), + ], + dim=1).reshape((-1, q_lora_B.shape[-1])) + self._set_weight(mg_attn.linear_qkv.lora_A['default'].weight, lora_A, 'linear_qkv.lora_A.weight') + self._set_weight(mg_attn.linear_qkv.lora_B['default'].weight, lora_B, 'linear_qkv.lora_B.weight') + else: + linear_qkv_weight = torch.cat([ + hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + ], + dim=1).reshape((-1, args.hidden_size)) + self._set_weight(mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight') + else: + q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ + 0] // num_query_groups + is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, + LoraParallelLinear) and self._is_peft_format + if self.pp_size > 1: + dist.all_reduce(is_lora, group=self.pp_group) + if is_lora: + lora_A = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_A['default'].weight.data, + 'linear_qkv.lora_A.default.weight') + lora_B = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_B['default'].weight.data, + 'linear_qkv.lora_B.default.weight') + if lora_A is not None: + self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) + for key in ['q_proj', 'k_proj', 'v_proj']: + hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() + lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) + hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() + hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, + q_dim:-kv_dim, :].reshape(-1, + lora_B.shape[-1]).clone() + hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone() + else: + mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data, + 'linear_qkv.weight') + if mg_attn_weight is not None: + mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) + hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone() + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, + q_dim:-kv_dim, :].reshape(-1, + args.hidden_size).clone() + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, + args.hidden_size).clone() + del 
mg_attn_weight + self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) + + # Copy bias + if args.add_qkv_bias and not self._is_peft_format: + if to_mcore: + linear_qkv_bias = torch.cat([ + hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), + hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)), + hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)), + ], + dim=1).reshape(-1) + self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') + else: + mg_attn_bias = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data, + 'linear_qkv.bias') + if mg_attn_bias is not None: + mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) + hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() + hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() + if args.qk_layernorm: + hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' + hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' + self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore) + self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore) + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_moe_state( + self, + mg_mlp, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + ): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + + hf_mlp = self.hf_layers[layer_idx].mlp + hf_gate_key = 'gate.wg.weight' if hasattr(hf_mlp.gate, 'wg') else 'gate.weight' + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) + if self.args.moe_router_enable_expert_bias: + self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, 'gate.e_score_correction_bias', to_mcore) + + if self.args.moe_shared_expert_intermediate_size: + for key in ['shared_expert', 'shared_experts', 'shared_mlp']: + if hasattr(hf_mlp, key): + hf_shared_expert_prefix = f'{key}.' 
+ shared_expert = getattr(hf_mlp, key) + hf_state_dict.update( + self._set_mlp_state( + mg_mlp.shared_experts, + hf_state_dict, + hf_shared_expert_prefix, + layer_idx, + to_mcore, + hf_mlp=shared_expert)) + if hasattr(hf_mlp, 'shared_expert_gate'): + self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', + to_mcore) + for ep_rank in range(self.ep_size): + mg_experts = mg_mlp.experts + expert_available = ep_rank == self.ep_rank + if not expert_available: + if to_mcore: + continue + else: + mg_experts = None + hf_state_dict.update( + self._set_expert_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank)) + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_expert_state( + self, + mg_mlp, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + ep_rank: int, + ): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_experts = self.hf_layers[layer_idx].mlp.experts + num_local_experts = self.args.num_experts // self.ep_size + # if hf_group_idx is not None: + # res[fc1_key] = hf_state_dict['gate_up_proj'][hf_group_idx].t() + # res[fc2_key] = hf_state_dict['down_proj'][hf_group_idx].t() + # else: + if to_mcore: + # linear_fc1 + fc1_weight = mg_mlp.linear_fc1.weight0 + fc1_weight = fc1_weight.new_empty(num_local_experts * 2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) + if hasattr(hf_experts[0], 'gate_up_proj'): + gate_up_proj_weight = torch.concat([ + hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.weight'].load() + for i in range(num_local_experts) + ], + dim=0) + else: + weight_list = [] + start_idx = ep_rank * num_local_experts + for i in range(num_local_experts): + gate_proj_weight = hf_state_dict[f'{start_idx + i}.gate_proj.weight'].load() + up_proj_weight = hf_state_dict[f'{start_idx + i}.up_proj.weight'].load() + weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0)) + gate_up_proj_weight = torch.concat(weight_list, dim=0) + del weight_list + self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight', is_expert=True) + fc1_weight = fc1_weight.view(num_local_experts, -1, fc1_weight.shape[-1]) + for i in range(num_local_experts): + getattr(mg_mlp.linear_fc1, f'weight{i}').data.copy_(fc1_weight[i].view(-1, fc1_weight.shape[-1])) + del fc1_weight + # linear_fc2 + fc2_weight = mg_mlp.linear_fc2.weight0 + fc2_weight = fc2_weight.new_empty(num_local_experts * fc2_weight.shape[0], fc2_weight.shape[1]) + down_proj_weight = torch.concat([ + hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load() + for i in range(num_local_experts) + ], + dim=0) + self._set_weight(fc2_weight, down_proj_weight, 'linear_fc2.weight', is_expert=True) + fc2_weight = fc2_weight.view(num_local_experts, -1, fc2_weight.shape[-1]) + for i in range(num_local_experts): + getattr(mg_mlp.linear_fc2, f'weight{i}').data.copy_(fc2_weight[i]) + else: + if mg_mlp is None: + fc1_weight = None + else: + fc1_weight = torch.concat([getattr(mg_mlp.linear_fc1, f'weight{i}') for i in range(num_local_experts)], + dim=0) + fc1_weight = fc1_weight.view(num_local_experts * 2, -1, fc1_weight.shape[1]) + gate_up_proj_weight = self._get_weight(fc1_weight, 'linear_fc1.weight', is_expert=True) + del fc1_weight + if gate_up_proj_weight is not None: + gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1, gate_up_proj_weight.shape[-1]) + for i in 
range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + if hasattr(hf_experts[i], 'gate_up_proj'): + hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].view( + -1, gate_up_proj_weight.shape[-1]).clone() + else: + hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone() + hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone() + del gate_up_proj_weight + # linear_fc2 + if mg_mlp is None: + fc2_weight = None + else: + fc2_weight = torch.concat([getattr(mg_mlp.linear_fc2, f'weight{i}') for i in range(num_local_experts)], + dim=0) + fc2_weight = fc2_weight.view(num_local_experts * 2, -1, fc2_weight.shape[1]) + down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=True) + del fc2_weight + if down_proj_weight is not None: + down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1]) + for i in range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].view( + -1, down_proj_weight.shape[-1]).clone() + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool, hf_mlp=None): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + if hf_mlp is None: + hf_mlp = self.hf_layers[layer_idx].mlp + if to_mcore: + if isinstance(mg_mlp.linear_fc1, LoraParallelLinear): + mg_lora_B = mg_mlp.linear_fc1.lora_B['default'].weight + mg_lora_B = mg_lora_B.new_empty(2, mg_lora_B.shape[0] // 2, mg_lora_B.shape[-1]) + if hasattr(hf_mlp, 'gate_up_proj'): + lora_A = hf_state_dict['gate_up_proj.lora_A.weight'].load() + lora_B = hf_state_dict['gate_up_proj.lora_B.weight'].load() + else: + lora_A = hf_state_dict['gate_proj.lora_A.weight'].load() + assert (lora_A == hf_state_dict['up_proj.lora_A.weight'].load() + ).all(), 'Need to ensure lora_A consistency between gate_proj and up_proj' + gate_lora_B = hf_state_dict['gate_proj.lora_B.weight'].load() + up_lora_B = hf_state_dict['up_proj.lora_B.weight'].load() + lora_B = torch.stack([gate_lora_B, up_lora_B], dim=0) + self._set_weight(mg_mlp.linear_fc1.lora_A['default'].weight, lora_A, 'linear_fc1.lora_A.default.weight') + self._set_weight(mg_lora_B, lora_B, 'linear_fc1.lora_B.default.weight') + mg_mlp.linear_fc1.lora_B['default'].weight.data.copy_(mg_lora_B.view(-1, mg_lora_B.shape[-1])) + else: + fc1_weight = mg_mlp.linear_fc1.weight + fc1_weight = fc1_weight.new_empty(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) + if hasattr(hf_mlp, 'gate_up_proj'): + gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load().view( + 2, -1, gate_up_proj_weight.shape[-1]) + else: + gate_proj_weight = hf_state_dict['gate_proj.weight'].load() + up_proj_weight = hf_state_dict['up_proj.weight'].load() + gate_up_proj_weight = torch.stack([gate_proj_weight, up_proj_weight], dim=0) + self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight') + mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1])) + else: + is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, + LoraParallelLinear) and self._is_peft_format + if self.pp_size > 1: + dist.all_reduce(is_lora, group=self.pp_group) + if is_lora: + if mg_mlp is None: + lora_A = None + lora_B = None + else: + lora_A = mg_mlp.linear_fc1.lora_A['default'].weight + lora_B = 
mg_mlp.linear_fc1.lora_B['default'].weight + lora_B = lora_B.view(2, lora_B.shape[0] // 2, lora_B.shape[1]) + lora_A = self._get_weight(lora_A, 'linear_fc1.lora_A.default.weight') + lora_B = self._get_weight(lora_B, 'linear_fc1.lora_B.default.weight') + if lora_A is not None: + if hasattr(hf_mlp, 'gate_up_proj'): + self._peft_target_modules.update({'gate_up_proj'}) + hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone() + hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.clone() + else: + self._peft_target_modules.update({'gate_proj', 'up_proj'}) + hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone() + hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone() + hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone() + hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone() + else: + if mg_mlp is None: + fc1_weight = None + else: + fc1_weight = mg_mlp.linear_fc1.weight + fc1_weight = fc1_weight.view(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) + gate_up_proj_weight = self._get_weight(None if fc1_weight is None else fc1_weight, 'linear_fc1.weight') + if gate_up_proj_weight is not None: + if hasattr(hf_mlp, 'gate_up_proj'): + hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view( + -1, gate_up_proj_weight.shape[-1]).clone() + else: + hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone() + hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone() + self._set_state_dict(mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore) + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_mla_attn_state( + self, + mg_attn, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + ): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) + if self.args.q_lora_rank is None: + self._set_state_dict(mg_attn, 'linear_q_proj.weight', hf_state_dict, 'q_proj.weight', to_mcore) + else: + self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'q_a_proj.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'q_b_proj.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_kv_down_proj.weight', hf_state_dict, 'kv_a_proj_with_mqa.weight', + to_mcore) + self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore) + if self.args.qk_layernorm: + self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight', + to_mcore) + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): + mg_attn = None if mg_layer is None else mg_layer.self_attention + if self.args.multi_latent_attention: + hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) + else: + hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict, + 'input_layernorm.weight', to_mcore) + return hf_state_dict + + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): + hf_mlp = 
self.hf_layers[layer_idx].mlp + is_moe = self._is_moe(hf_mlp.state_dict()) + mg_mlp = None if mg_layer is None else mg_layer.mlp + if is_moe: + hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', + to_mcore) + else: + hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, + 'post_attention_layernorm.weight', to_mcore) + return hf_state_dict + + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + mg_models = iter(mg_models) + mg_model = next(mg_models) + if not to_mcore or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, + 'model.embed_tokens.weight', to_mcore) + if to_mcore: + yield + else: + yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + hf_state_dict = {} + for layer_idx in tqdm( + range(self.args.num_layers), dynamic_ncols=True, desc='Converting: ', disable=self.disable_tqmd): + start_idx = mg_model.decoder.layers[0].layer_number - 1 + mg_layer_available = (start_idx <= layer_idx < mg_model.decoder.layers[-1].layer_number) + if mg_layer_available: + mg_layer = mg_model.decoder.layers[layer_idx - start_idx] + else: + if to_mcore: + continue + else: + mg_layer = None + if not to_mcore and self.pp_size > 1: + has_model = torch.tensor([mg_layer is not None], dtype=torch.bool, device='cuda') + dist.all_reduce(has_model, group=self.pp_group) + if not has_model: + mg_model = next(mg_models) + continue + res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, to_mcore) + if to_mcore: + yield + else: + yield from list(self._add_prefix(res, hf_prefix).items()) + hf_state_dict = {} + if not to_mcore or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if self.args.untie_embeddings_and_output_weights: + hf_lm_head_key = 'lm_head.weight' + if not to_mcore and self.args.task_type == 'seq_cls': + hf_lm_head_key = 'score.weight' + self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore) + self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', + to_mcore) + if to_mcore: + yield + else: + yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + + def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False): + self._is_peft_format = is_peft_format + with SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader: + state_dict = loader.get_state_dict() + hf_prefix = 'base_model.model.' 
if is_peft_format else '' + list(self._convert([mg_model], state_dict, hf_prefix, True)) + + def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False, is_peft_format: bool = False): + # TODO: modules_to_save + self._target_device = target_device + self._only_last_rank = only_last_rank + self._is_peft_format = is_peft_format + self._peft_target_modules = set() + hf_prefix = 'base_model.model.' if is_peft_format else '' + yield from self._convert(mg_models, {}, hf_prefix, False) + + def save_weights(self, mg_models, output_dir: str, is_peft_format: bool = False) -> None: + """Save the mg_model checkpoint in HF format""" + saver = StreamingSafetensorSaver( + save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format) + for k, v in self.export_weights( + mg_models, target_device='cpu', only_last_rank=True, is_peft_format=is_peft_format): + saver.add_tensor(k, v) + saver.finalize() + if is_last_rank(): + if is_peft_format: + peft_config = copy(mg_models[0].peft_config['default']) + peft_config.target_modules = self._peft_target_modules + peft_config.save_pretrained(output_dir) + else: + # TODO: new_special_tokens + self.hf_model.config.save_pretrained(output_dir) + save_checkpoint( + None, + self.processor, + output_dir, + model_dirs=[self.hf_model.model_info.model_dir], + additional_saved_files=self.hf_model.model_meta.additional_saved_files) + logger.info_if(f'Successfully saved `safetensors` model weights in `{output_dir}`.', cond=is_last_rank()) diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py b/swift/megatron/model/mm_gpt/qwen3_vl.py index 8220a15e41..bc19f17dc9 100644 --- a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -15,7 +15,7 @@ from swift.llm import ModelType, to_device from ..constant import MegatronModelType -from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore +from ..gpt.hf2mcore import _add_prefix, _remove_prefix, convert_hf2mcore from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf from ..mm_gpt_model import MultimodalGPTModel from ..register import register_megatron_model @@ -502,17 +502,14 @@ def __init__(self, *args, **kwargs): visual_cls=Qwen3Omni_Vit)) -def convert_hf2mcore_qwen3_vl(hf_model, mg_model): - language_model = hf_model.model.language_model - mg_language_model = mg_model.language_model +def convert_hf2mcore_qwen3_vl(state_dict, prefix=''): args = get_args() - mg_language_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) + mg_state_dict = {} if args.untie_embeddings_and_output_weights: - mg_language_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_language_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state_hf2mcore(args, mg_language_model, language_model, layer_idx) - mg_model.visual.visual.load_state_dict(hf_model.model.visual.state_dict()) + mg_state_dict['language_model.output_layer.weight'] = state_dict['lm_head.weight'] + mg_state_dict.update(convert_hf2mcore(state_dict, 'language_model.')) + mg_state_dict.update(_add_prefix(_remove_prefix(state_dict, 'model.visual.'), 'visual.visual.')) + return _add_prefix(mg_state_dict, prefix) def convert_mcore2hf_qwen3_vl(hf_model, mg_model): diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 814cd82213..392785e085 100644 --- a/swift/megatron/model/model_provider.py +++ 
b/swift/megatron/model/model_provider.py @@ -14,14 +14,12 @@ if TYPE_CHECKING: from .gpt_model import GPTModel - from .mm_gpt_model import MultimodalGPTModel # Code borrowed from NVIDIA/Megatron-LM -def model_provider( - pre_process=True, - post_process=True, - vp_stage: Optional[int] = None) -> Union['GPTModel', 'MultimodalGPTModel', megatron.legacy.model.GPTModel]: +def model_provider(pre_process=True, + post_process=True, + vp_stage: Optional[int] = None) -> Union['GPTModel', megatron.legacy.model.GPTModel]: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 93f892c2e8..877a09db92 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,12 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from argparse import ArgumentParser from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Callable, List, Optional, Type import torch.nn as nn -from transformers import PretrainedConfig from swift.llm import MODEL_MAPPING +from .constant import MLLMMegatronModelType +from .gpt_bridge import GPTBridge +from .gpt_model import GPTModel +from .mm_gpt_model import MultimodalGPTModel from .model_provider import model_provider as model_provider_func MEGATRON_MODEL_MAPPING = {} @@ -17,17 +20,18 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - convert_mcore2hf: Callable[[nn.Module, nn.Module], None] - convert_hf2mcore: Callable[[nn.Module, nn.Module], None] - - model_cls: Type[nn.Module] - convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] + is_multimodal: bool = False + bridge_cls: Type[GPTBridge] = GPTBridge get_transformer_layer_spec: Optional[Callable] = None model_provider: Callable[[], nn.Module] = model_provider_func visual_cls: Optional[Type[nn.Module]] = None extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None + @property + def model_cls(self): + return MultimodalGPTModel if self.is_multimodal else GPTModel + def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = megatron_model_meta.megatron_model_type @@ -36,7 +40,8 @@ def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: model_meta.support_megatron = True if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.') - + if megatron_model_type in MLLMMegatronModelType.__dict__: + megatron_model_meta.is_multimodal = True MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 15b81434f9..31978b343c 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -7,10 +7,10 @@ from swift.llm import TEMPLATE_MAPPING from swift.llm.train import SwiftSft -from swift.utils import get_logger, is_master, plot_images +from swift.utils import get_logger, is_last_rank, plot_images from ..argument import MegatronTrainArguments from ..trainers import MegatronTrainer -from ..utils import patch_megatron_tokenizer +from ..utils import get_padding_to from .utils import build_streaming_dataloader logger = get_logger() @@ -35,22 +35,14 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No with torch.device('meta'): 
self.model, self.processor = args.get_model_processor(**kwargs) self._prepare_template() - patch_megatron_tokenizer(self.processor) args.init_model_args(self.tokenizer, self.processor.model_info.config) args.save_args(args.save) self.template.use_megatron = True self.trainer = self.prepare_trainer() def _get_data_collator(self): - args = self.args data_collator = self.template.data_collator - padding_to = None - if args.tensor_model_parallel_size > 1 and args.sequence_parallel: - padding_to = args.tensor_model_parallel_size - if args.context_parallel_size > 1: - padding_to = (padding_to or 1) * args.context_parallel_size - if args.fp8_format: - padding_to = max((padding_to or 1) * 8, 16) + padding_to = get_padding_to(self.args) logger.info(f'padding_to: {padding_to}') data_collator = partial(data_collator, padding_to=padding_to) return data_collator @@ -69,7 +61,7 @@ def run(self): self.trainer.train(train_dataset, val_dataset, data_collator) finally: # Visualization - if is_master(): + if is_last_rank(): images_dir = os.path.join(args.save, 'images') logger.info(f'images_dir: {images_dir}') plot_images(images_dir, args.tensorboard_dir) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 52a6f7077f..b483977214 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import collections import os +import shutil import time from abc import ABC, abstractmethod from contextlib import contextmanager @@ -11,7 +12,6 @@ import torch import torch.nn from megatron.core import mpu -from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel import get_forward_backward_func @@ -32,7 +32,8 @@ from swift.plugin import MeanMetric from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger -from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model +from ..tuners import LoraParallelLinear +from ..utils import adapter_state_dict_context, copy_original_module_weight, patch_merge_fn, prepare_mcore_model from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) @@ -47,6 +48,7 @@ def __init__(self, args, template): self.stimer = StragglerDetector() self.unwrapped_models = [] self.peft_models = [] + self._bridge = None logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last') # for evaluate @@ -61,6 +63,12 @@ def _get_mean_metric(): } self.megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + @property + def bridge(self): + if self._bridge is None: + self._bridge = self.args.megatron_model_meta.bridge_cls() + return self._bridge + @contextmanager def _get_iters(self, train_dataset, val_dataset): origin_initialize_megatron = training.initialize_megatron @@ -126,29 +134,6 @@ def new_cyclic_iter(self, iterable): def _replace_data_iterator(self, data_iterator, model): return data_iterator - @staticmethod - def _patch_merge_fn(state_dict_model): - # https://github.com/NVIDIA/Megatron-LM/issues/1380 - - def sh_ten_merge_fn(sub_state_dict): - with torch.no_grad(): - shared_storage = 
sub_state_dict[0].untyped_storage() - if all(shared_storage.data_ptr() == tensor.untyped_storage().data_ptr() for tensor in sub_state_dict): - element_size = sub_state_dict[0].element_size() - total_numel = sum(tensor.numel() for tensor in sub_state_dict) - if shared_storage.nbytes() == total_numel * element_size: - dim_0 = sum(tensor.shape[0] for tensor in sub_state_dict) - shape = (dim_0, ) + sub_state_dict[0].shape[1:] - combined_tensor = torch.empty( - shape, dtype=sub_state_dict[0].dtype, - device=sub_state_dict[0].device).set_(shared_storage, 0, shape) - return combined_tensor - return torch.cat(sub_state_dict) - - for v in state_dict_model.values(): - if isinstance(v, ShardedTensorFactory) and 'apply_swiglu_sharded_factory' in v.merge_fn.__qualname__: - v.merge_fn = sh_ten_merge_fn - def _load_adapter_base_checkpoint(self, *_args, **kwargs): adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter' sharded_state_dict = kwargs.get('sharded_state_dict') @@ -169,7 +154,7 @@ def _load_adapter_base_checkpoint(self, *_args, **kwargs): v.key = v.key.replace(f'.{adapter_name}.', '.default.') state_dict_model[k] = v sharded_state_dict[model_k] = state_dict_model - self._patch_merge_fn(state_dict_model) + patch_merge_fn(state_dict_model) # TODO: check res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs) for model_k in model_keys: state_dict = res[0][model_k] @@ -185,7 +170,7 @@ def _load_base_checkpoint(self, *_args, **kwargs): model_keys = [k for k in sharded_state_dict.keys() if k.startswith('model')] if self.args.train_type == 'full': for k in model_keys: - self._patch_merge_fn(sharded_state_dict[k]) + patch_merge_fn(sharded_state_dict[k]) return checkpointing.origin__load_base_checkpoint(*_args, **kwargs) mapping = {} for model_k in model_keys: @@ -211,7 +196,7 @@ def _load_base_checkpoint(self, *_args, **kwargs): v.key = v.key.replace('.modules_to_save.default', '') state_dict_model[k] = v sharded_state_dict[model_k] = state_dict_model - self._patch_merge_fn(state_dict_model) + patch_merge_fn(state_dict_model) res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs) for model_k in model_keys: state_dict = res[0][model_k] @@ -251,13 +236,20 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs): def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - def new_model_provider_func(*args, **kwargs): - model = model_provider_func(*args, **kwargs) + args = get_args() + + def new_model_provider_func(*_args, **kwargs): + model = model_provider_func(*_args, **kwargs) + if args.load_hf_checkpoint: + self.bridge.load_weights(model, args.model_info.model_dir) self.unwrapped_models.append(model) - self.peft_models.append(prepare_mcore_model(model)) + peft_model = prepare_mcore_model(model) + if args.load_hf_checkpoint and args.train_type == 'lora' and args.adapters: + assert len(args.adapters) == 1, 'Currently only support one adapter' + self.bridge.load_weights(model, args.adapters[0], is_peft_format=True) + self.peft_models.append(peft_model) return model - args = get_args() self._init_multimodal_full(args) with self._patch_load_state_dict(self._load_base_checkpoint): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( @@ -724,9 +716,39 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear return report_memory_flag - def save_checkpoint(self, *args, **kwargs): - with adapter_state_dict_context(): - return self._origin_save_checkpoint(*args, **kwargs) + 
def merge_lora_adapters(self): + """Merge LoRA adapters into base model weights for vLLM inference.""" + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Merge all active adapters + module.merge() + + def unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Unmerge to restore separate LoRA weights for training + module.unmerge() + + def save_checkpoint(self, iteration, *_args, **kwargs): + args = get_args() + if args.save_hf_checkpoint: + if args.train_type == 'lora' and args.merge_lora: + self.merge_lora_adapters() + output_dir = os.path.join(args.save, f'checkpoint-{iteration}') + save_peft_format = args.train_type == 'lora' and not args.merge_lora + self.bridge.save_weights(self.unwrapped_models, output_dir, is_peft_format=save_peft_format) + if is_last_rank(): + args_path = os.path.join(os.path.dirname(output_dir), 'args.json') + if os.path.exists(args_path): + shutil.copy(args_path, os.path.join(output_dir, 'args.json')) + if args.train_type == 'lora' and args.merge_lora: + self.unmerge_lora_adapters() + else: + with adapter_state_dict_context(): + return self._origin_save_checkpoint(iteration, *_args, **kwargs) def _patch_megatron(self): # support max_epochs diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 5e65b64060..026f01e8f7 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -53,35 +53,6 @@ def null_ref_context(self): for m in self.peft_models: m.set_adapter('default') - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): args = get_args() per_token_logps = -output_tensor diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index f9ad78ef50..69b3a0a4ed 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -422,6 +422,39 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N if origin_device.type == 'cpu': self.to(device=origin_device) + def unmerge(self) -> None: + """ + Unmerge all merged adapter weights from the base weights. + This method reverses the merge operation by subtracting the LoRA delta weights + from the base layer weights, restoring the original base weights. 
+ """ + if not self.merged: + # No adapters to unmerge + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + if origin_device.type == 'cpu': + self.to(device=get_current_device()) + + for active_adapter in self.merged_adapters: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + # Subtract the delta weight to unmerge + orig_weight.data -= delta_weight + + # Clear the merged adapters list + self.merged_adapters = [] + + if origin_device.type == 'cpu': + self.to(device=origin_device) + def dispatch_megatron( target: torch.nn.Module, diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 1afd505df0..e996e7362d 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .convert import convert_hf2mcore, convert_mcore2hf -from .patcher import patch_megatron_tokenizer -from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, - tuners_sharded_state_dict) +from .config import convert_hf_config +from .io_utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver +from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard +from .utils import (adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_padding_to, + prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/utils/config.py similarity index 50% rename from swift/megatron/model/gpt/config.py rename to swift/megatron/utils/config.py index e779a827b6..fba9dee26c 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/utils/config.py @@ -1,11 +1,92 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
 from typing import Any, Dict
 
-from ..config import convert_hf_config
+from swift.utils import get_logger
 
+logger = get_logger()
+
+config_mapping = {
+    'num_layers': ['num_hidden_layers'],
+    'hidden_size': ['hidden_size'],
+    'ffn_hidden_size': ['intermediate_size'],
+    'num_attention_heads': ['num_attention_heads'],
+    'num_query_groups': ['num_key_value_heads'],
+    'max_position_embeddings': ['max_position_embeddings'],
+    'norm_epsilon': ['rms_norm_eps'],
+    'rotary_base': ['rope_theta'],
+    'padded_vocab_size': ['vocab_size'],
+    'attention_dropout': ['attention_dropout'],
+    'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
+    'swiglu': ['hidden_act'],
+    'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'],
+    'disable_bias_linear': ['mlp_bias'],
+    'kv_channels': ['head_dim', 'v_head_dim'],
+    'architectures': ['architectures'],
+    # moe
+    'moe_ffn_hidden_size': ['moe_intermediate_size'],
+    'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
+    'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'],
+    'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'],
+    'moe_router_pre_softmax': ['norm_topk_prob'],
+    # deepseek
+    'q_lora_rank': ['q_lora_rank'],
+    'kv_lora_rank': ['kv_lora_rank'],
+    'moe_router_score_function': ['scoring_func'],
+    'qk_head_dim': ['qk_nope_head_dim'],
+    'qk_pos_emb_head_dim': ['qk_rope_head_dim'],
+    'moe_router_topk_scaling_factor': ['routed_scaling_factor'],
+    'qk_layernorm': ['use_qk_norm'],
+    # qwen3_next
+    'linear_num_value_heads': ['linear_num_value_heads'],
+    'linear_num_key_heads': ['linear_num_key_heads'],
+    'linear_key_head_dim': ['linear_key_head_dim'],
+    'linear_value_head_dim': ['linear_value_head_dim'],
+    'linear_conv_kernel_dim': ['linear_conv_kernel_dim'],
+    'full_attention_interval': ['full_attention_interval'],
+    # other
+    'original_max_position_embeddings': ['original_max_position_embeddings'],
+    'partial_rotary_factor': ['partial_rotary_factor'],
+    'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'],
+    'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'],
+}
 
-def convert_gpt_hf_config(config) -> Dict[str, Any]:
-    res = convert_hf_config(config)
+
+def _convert_config(config, _internal_call=False) -> Dict[str, Any]:
+    megatron_config = {}
+    for k, hf_keys in config_mapping.items():
+        for hf_k in hf_keys:
+            if hasattr(config, hf_k):
+                hf_v = getattr(config, hf_k)
+                if hf_v is None:
+                    continue
+                if k == 'rotary_base':
+                    megatron_config[k] = int(hf_v)
+                elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}:
+                    megatron_config[k] = not hf_v
+                elif k == 'swiglu':
+                    if hf_v == 'silu':
+                        megatron_config[k] = True
+                else:
+                    if k == 'kv_lora_rank':
+                        megatron_config['multi_latent_attention'] = True
+                    elif k == 'architectures':
+                        if _internal_call:
+                            k = 'llm_architectures'
+                    megatron_config[k] = hf_v
+                break
+    for key in ['text_config', 'llm_config', 'thinker_config']:
+        if hasattr(config, key):
+            # recurse with the internal helper; `convert_hf_config` does not accept `_internal_call`
+            megatron_config.update(_convert_config(getattr(config, key), _internal_call=True))
+    # compat llama3
+    if getattr(config, 'rope_scaling', None) is not None:
+        if isinstance(config.rope_scaling, int):
+            # keep this a dict (a trailing comma here would turn it into a tuple)
+            megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}
+        elif isinstance(config.rope_scaling, dict):
+            megatron_config['rope_scaling'] = config.rope_scaling
+    return megatron_config
+
+
+def convert_hf_config(config) -> Dict[str, Any]:
+    res = _convert_config(config)
architectures = res.get('architectures') if isinstance(architectures, list) and architectures: architectures = architectures[0] diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py new file mode 100644 index 0000000000..cbe9505efd --- /dev/null +++ b/swift/megatron/utils/io_utils.py @@ -0,0 +1,168 @@ +import os +from functools import partial +from typing import Literal + +import json +from safetensors.torch import safe_open, save_file + +from swift.utils import is_last_rank, is_master + + +class LazyTensor: + + def __init__(self, tensor=None, loader=None): + """You need to provide a tensor or loader""" + self.tensor = tensor + self.loader = loader + + def load(self): + if self.tensor is None: + return self.loader() + return self.tensor + + +class SafetensorLazyLoader: + + def __init__(self, hf_model_dir: str, is_peft_format: bool = False): + self.hf_model_dir = hf_model_dir + self.is_peft_format = is_peft_format + self._weight_map = {} + self._file_handles = {} + self._load_index() + + def _open_file(self, filename: str): + """Open a safetensors file if not already open.""" + if filename not in self._file_handles: + file_path = os.path.join(self.hf_model_dir, filename) + self._file_handles[filename] = safe_open(file_path, framework='pt') + return self._file_handles[filename] + + def _load_index(self): + """Load the model index file to get weight map.""" + index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') + + if os.path.exists(index_path): + with open(index_path, 'r') as f: + self._index_file = json.load(f) + self._weight_map = self._index_file.get('weight_map', {}) + else: + if self.is_peft_format: + safetensors_fname = 'adapter_model.safetensors' + else: + safetensors_fname = 'model.safetensors' + # Single file model + safetensors_file = os.path.join(self.hf_model_dir, safetensors_fname) + if os.path.exists(safetensors_file): + with safe_open(safetensors_file, framework='pt') as f: + for key in f.keys(): + self._weight_map[key] = safetensors_fname + + def get_state_dict(self): + res = {} + for k in self._weight_map.keys(): + res[k] = LazyTensor(loader=partial(self._load_tensor, key=k)) + return res + + def _load_tensor(self, key): + filename = self._weight_map[key] + file_handle = self._open_file(filename) + return file_handle.get_tensor(key) + + def close(self): + self._file_handles.clear() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class StreamingSafetensorSaver: + + def __init__( + self, + save_dir, + max_shard_size: str = '5GB', + save_rank: Literal['master', 'last'] = 'last', + is_peft_format: bool = False, + ) -> None: + self.save_dir = save_dir + if isinstance(max_shard_size, str): + if max_shard_size.endswith('GB'): + max_shard_size = int(max_shard_size[:-2]) + else: + raise ValueError(f'Invalid max_shard_size: {max_shard_size}') + self.max_shard_size = max_shard_size * 1000**3 + self.current_shard = {} + self.current_shard_size = 0 + self.total_size = 0 + self.shard_index = 1 + self.weight_map = {} + self.is_save_rank = is_last_rank() if save_rank == 'last' else is_master() + self.is_peft_format = is_peft_format + if self.is_save_rank: + os.makedirs(save_dir, exist_ok=True) + + def add_tensor(self, name, tensor): + if not self.is_save_rank: + return + tensor_size = tensor.numel() * tensor.element_size() + if (self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard + and not self.is_peft_format): + self._save_current_shard() + + 
self.current_shard[name] = tensor.cpu().contiguous() + self.current_shard_size += tensor_size + + def _save_current_shard(self, shard_filename: str = None): + if not self.current_shard: + return + if shard_filename is None: + if self.is_peft_format: + shard_filename = 'adapter_model.safetensors' + else: + shard_filename = f'model-{self.shard_index:05d}-of-?????.safetensors' + shard_path = os.path.join(self.save_dir, shard_filename) + save_file(self.current_shard, str(shard_path)) + for key in self.current_shard.keys(): + self.weight_map[key] = shard_filename + + self.total_size += self.current_shard_size + self.current_shard = {} + self.current_shard_size = 0 + self.shard_index += 1 + + def finalize(self): + if not self.is_save_rank: + return + if self.current_shard: + self._save_current_shard() + if self.is_peft_format: + return + total_shards = self.shard_index - 1 + # rename `?????` + for i in range(1, total_shards + 1): + old_path = os.path.join(self.save_dir, f'model-{i:05d}-of-?????.safetensors') + if total_shards == 1: + new_name = 'model.safetensors' + else: + new_name = f'model-{i:05d}-of-{total_shards:05d}.safetensors' + new_path = os.path.join(self.save_dir, new_name) + if os.path.exists(old_path): + os.rename(old_path, new_path) + + if total_shards > 1: + updated_weight_map = {} + for key, filename in self.weight_map.items(): + new_filename = filename.replace('?????', f'{total_shards:05d}') + updated_weight_map[key] = new_filename + + self._save_index(updated_weight_map) + + def _save_index(self, weight_map): + index = {'metadata': {'total_size': self.total_size}, 'weight_map': weight_map} + + index_path = os.path.join(self.save_dir, 'model.safetensors.index.json') + with open(index_path, 'w') as f: + json.dump(index, f, indent=2) diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py index 49dec85dea..3c35a42124 100644 --- a/swift/megatron/utils/patcher.py +++ b/swift/megatron/utils/patcher.py @@ -1,20 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py
index 6d4ae82228..48b9c01f52 100644
--- a/swift/megatron/utils/utils.py
+++ b/swift/megatron/utils/utils.py
@@ -3,9 +3,11 @@
 from copy import deepcopy
 from typing import Optional, Tuple
 
+import torch
 import torch.distributed as dist
 from megatron.core import mpu
 from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
+from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.transformer.moe.router import TopKRouter
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
@@ -267,3 +269,45 @@ def copy_ref_adapter_weight(model, ref_adapter_name: str):
             sub_module = module.modules_to_save
             if 'default' in sub_module and ref_adapter_name in sub_module:
                 sub_module[ref_adapter_name].load_state_dict(sub_module['default'].state_dict())
+
+
+def forward_step_helper(model, inputs, dtype=None):
+    args = get_args()
+    if mpu.is_pipeline_first_stage():
+        micro_batch_size = 1  # use qkv_format 'thd'
+        seq_length = inputs['input_ids'].shape[1]
+        if args.sequence_parallel:
+            seq_length //= mpu.get_tensor_model_parallel_world_size()
+        recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size],
+                                         device=torch.cuda.current_device(),
+                                         dtype=torch.int64)
+    else:
+        recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64)
+        recv_from_prev_pipeline_rank_(recv_shape_buffer)
+    if not mpu.is_pipeline_last_stage():
+        send_to_next_pipeline_rank(recv_shape_buffer)
+    shape = recv_shape_buffer.tolist()
+
+    if not mpu.is_pipeline_first_stage():
+        dtype = dtype or args.params_dtype
+        recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype)
+        recv_from_prev_pipeline_rank_(recv_buffer)
+        model.set_input_tensor(recv_buffer)
+    output_tensor = model(**inputs)
+    if not mpu.is_pipeline_last_stage():
+        send_to_next_pipeline_rank(output_tensor)
+        output_tensor = None
+
+    return output_tensor
+
+
+def get_padding_to(args):
+    padding_to = None
+    if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
+        padding_to = args.tensor_model_parallel_size
+    if args.context_parallel_size > 1:
+        padding_to = (padding_to or 1) * args.context_parallel_size
+    fp8_format = getattr(args, 'fp8_format', None) or getattr(args, 'fp8', None)
+    if fp8_format is not None:
+        padding_to = max((padding_to or 1) * 8, 16)
+    return padding_to
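Note: `get_padding_to` composes the alignment constraints of the enabled parallel modes: with sequence parallelism the padded length must be divisible by the TP size, context parallelism multiplies that by the CP size, and the FP8 path bumps the multiple by 8 with a minimum of 16. A quick arithmetic check (illustration only; the rule is restated against a plain namespace instead of Megatron's parsed args):

```python
import math
from types import SimpleNamespace


def padding_multiple(args):
    # Restates the rule encoded in get_padding_to above, for a sanity check.
    padding_to = None
    if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
        padding_to = args.tensor_model_parallel_size
    if args.context_parallel_size > 1:
        padding_to = (padding_to or 1) * args.context_parallel_size
    if getattr(args, 'fp8', None) is not None:
        padding_to = max((padding_to or 1) * 8, 16)
    return padding_to


args = SimpleNamespace(
    tensor_model_parallel_size=4, sequence_parallel=True, context_parallel_size=2, fp8='e4m3')
multiple = padding_multiple(args)                       # 4 * 2 * 8 = 64
print(multiple, math.ceil(1000 / multiple) * multiple)  # 64 1024: a 1000-token pack is padded to 1024
```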
diff --git a/tests/megatron/export/test_export.py b/tests/megatron/export/test_export.py
new file mode 100644
index 0000000000..fa31528950
--- /dev/null
+++ b/tests/megatron/export/test_export.py
@@ -0,0 +1,27 @@
+from swift.megatron import MegatronExportArguments, megatron_export_main
+
+
+def test_to_mcore():
+    megatron_export_main(
+        MegatronExportArguments(
+            model='Qwen/Qwen2.5-7B-Instruct',
+            save='Qwen2.5-7B-Instruct-mcore',
+            to_mcore=True,
+            exist_ok=True,
+            tensor_model_parallel_size=2,
+            test_convert_precision=True))
+
+
+def test_to_hf():
+    megatron_export_main(
+        MegatronExportArguments(
+            load='Qwen2.5-7B-Instruct-mcore',
+            to_hf=True,
+            exist_ok=True,
+            tensor_model_parallel_size=2,
+            test_convert_precision=True))
+
+
+if __name__ == '__main__':
+    # test_to_mcore()
+    test_to_hf()
diff --git a/tests/megatron/test_save.py b/tests/megatron/test_save.py
index cfc78182ae..c19b7792e6 100644
--- a/tests/megatron/test_save.py
+++ b/tests/megatron/test_save.py
@@ -12,7 +12,7 @@ def get_mg_model_tokenizer():
     _, processor = get_model_tokenizer(model_id, load_model=False)
     megatron_model_meta = get_megatron_model_meta(processor.model_meta.model_type)
     model_info = processor.model_info
-    kwargs = megatron_model_meta.convert_hf_config(model_info.config)
+    kwargs = convert_hf_config(model_info.config)
     megatron_args = MegatronArguments(
         **kwargs,
         seq_length=1,
@@ -22,7 +22,6 @@ def get_mg_model_tokenizer():
         save='mcore-hf-test',
         no_load_optim=True,
         no_load_rng=True)
-    patch_megatron_tokenizer(processor)
     extra_args = megatron_args.parse_to_megatron()
     initialize_megatron(args_defaults=extra_args)
     mg_model = megatron_model_meta.model_provider()
@@ -57,5 +56,4 @@ def test_save():
     from swift.utils import set_default_ddp_config
     from swift.megatron.argument import MegatronArguments
     from swift.megatron.model import get_megatron_model_meta
-    from swift.megatron.utils import patch_megatron_tokenizer
     test_save()
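Note: the export tests rely on `test_convert_precision=True` to validate the HF → mcore → HF round trip end to end. For a cheaper manual spot check after a conversion, the lazy loader added in this PR can compare a handful of tensors between the original and the re-exported checkpoint. A sketch only: both directory names are placeholders for local safetensors directories, and the tolerance is arbitrary.

```python
import torch

from swift.megatron.utils.io_utils import SafetensorLazyLoader

with SafetensorLazyLoader('Qwen2.5-7B-Instruct') as original, \
        SafetensorLazyLoader('Qwen2.5-7B-Instruct-roundtrip') as roundtrip:
    orig_sd = original.get_state_dict()
    rt_sd = roundtrip.get_state_dict()
    assert orig_sd.keys() == rt_sd.keys()
    for key in sorted(orig_sd)[:10]:  # spot-check a few weights only
        assert torch.allclose(orig_sd[key].load(), rt_sd[key].load(), atol=1e-5), key
```

Since tensors are read one at a time and released after each comparison, this check also works for checkpoints larger than available RAM.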