From b6bb458ed3ea0f8ae57910343a2799de0682375d Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Thu, 8 Jan 2026 18:11:52 +0800 Subject: [PATCH] support single_process_simulator for flagscale --- megatron/core/optimizer/__init__.py | 152 ++++++++++++--------- megatron/plugin/hetero/parallel_context.py | 87 +++++++++--- 2 files changed, 156 insertions(+), 83 deletions(-) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 0276c5e0e70..48a9e6772de 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -9,6 +9,8 @@ from torch.optim import SGD as CPUSGD from torch.optim import AdamW as CPUAdam from .muon import Muon +import os + try: from transformer_engine.pytorch.optimizers import FusedAdam as Adam @@ -179,81 +181,107 @@ def _get_param_groups( assert (config_for_param, uses_default_config) == configs_map[key] else: configs_map[key] = (config_for_param, uses_default_config) - + # Distributed checkpoint requires all ranks to have the same param groups, # so we need to align the param groups across ranks, otherwise we may have # runtime error when loading the checkpoint or numerical error when resuming training. - params_key = list(params_map.keys()) - gathered_params_key = [None for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather_object(gathered_params_key, params_key) - for keys in gathered_params_key: - for key in keys: - if key not in params_key: - params_key.append(key) - - # for muon optimizer - # For muon optimizer, we need to add the muon params key to the params_key - # so we need to align the param groups across ranks, otherwise we may have - # runtime error when loading the checkpoint or numerical error when resuming training. - muon_params_key = list(muon_params_map.keys()) - gathered_muon_params_key = [None for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather_object(gathered_muon_params_key, muon_params_key) - for keys in gathered_muon_params_key: - for key in keys: - if key not in muon_params_key: - muon_params_key.append(key) - - param_groups = [] - for key in params_key: - wd_mult, is_expert_parallel, is_vision_model_param, _ = key - params = params_map[key] if key in params_map else [] - config, uses_default_config = None, True - if key not in configs_map: - assert params == [] - else: - config, uses_default_config = configs_map[key] - assert config is not None + if os.environ.get("ENABLE_SIMULATOR") == "1": + param_groups = [] + params_key = list(params_map.keys()) + for key in params_key: + wd_mult, is_expert_parallel, is_vision_model_param, _ = key + params = params_map[key] if key in params_map else [] + config, uses_default_config = None, True + if key not in configs_map: + assert params == [] + else: + config, uses_default_config = configs_map[key] + assert config is not None # TODO: Remove "backwards compatible" fields below eventually. - param_group = { - 'params': params, - 'wd_mult': wd_mult, # For backwards compatibility. - 'lr_mult': 1.0, # For backwards compatibility. - 'is_expert_parallel': is_expert_parallel, - 'is_decoupled_lr': False, # For backwards compatibility. - 'default_config': uses_default_config, - 'is_vision_model_param': is_vision_model_param, - } - - # Stick relevant fields into param_group from config object. 
- if config is not None: - param_group['max_lr'] = config.lr if not is_vision_model_param else config.lr * config.vision_ration # NOTE(lizhiyu): change the ration here - param_group['min_lr'] = config.min_lr - # TODO: Add other relevant arguments (e.g., weight decay, optimizer) - # here as well. - param_groups.append(param_group) - - for key in muon_params_key: - wd_mult, is_expert_parallel, _ = key - params = muon_params_map[key] if key in muon_params_map else [] - config, uses_default_config = None, True - if key not in configs_map: - assert params == [] - else: - config, uses_default_config = configs_map[key] - assert config is not None + param_group = { + 'params': params, + 'wd_mult': wd_mult, # For backwards compatibility. + 'lr_mult': 1.0, # For backwards compatibility. + 'is_expert_parallel': is_expert_parallel, + 'is_decoupled_lr': False, # For backwards compatibility. + 'default_config': uses_default_config, + 'is_vision_model_param': is_vision_model_param, + } + param_groups.append(param_group) + + else: + params_key = list(params_map.keys()) + gathered_params_key = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_params_key, params_key) + for keys in gathered_params_key: + for key in keys: + if key not in params_key: + params_key.append(key) + + # for muon optimizer + # For muon optimizer, we need to add the muon params key to the params_key + # so we need to align the param groups across ranks, otherwise we may have + # runtime error when loading the checkpoint or numerical error when resuming training. + muon_params_key = list(muon_params_map.keys()) + gathered_muon_params_key = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_muon_params_key, muon_params_key) + for keys in gathered_muon_params_key: + for key in keys: + if key not in muon_params_key: + muon_params_key.append(key) + + param_groups = [] + for key in params_key: + wd_mult, is_expert_parallel, is_vision_model_param, _ = key + params = params_map[key] if key in params_map else [] + config, uses_default_config = None, True + if key not in configs_map: + assert params == [] + else: + config, uses_default_config = configs_map[key] + assert config is not None - param_groups.append( - { + # TODO: Remove "backwards compatible" fields below eventually. + param_group = { 'params': params, 'wd_mult': wd_mult, # For backwards compatibility. 'lr_mult': 1.0, # For backwards compatibility. 'is_expert_parallel': is_expert_parallel, 'is_decoupled_lr': False, # For backwards compatibility. 'default_config': uses_default_config, - 'use_muon': True, + 'is_vision_model_param': is_vision_model_param, } - ) + + # Stick relevant fields into param_group from config object. + if config is not None: + param_group['max_lr'] = config.lr if not is_vision_model_param else config.lr * config.vision_ration # NOTE(lizhiyu): change the ration here + param_group['min_lr'] = config.min_lr + # TODO: Add other relevant arguments (e.g., weight decay, optimizer) + # here as well. + param_groups.append(param_group) + + for key in muon_params_key: + wd_mult, is_expert_parallel, _ = key + params = muon_params_map[key] if key in muon_params_map else [] + config, uses_default_config = None, True + if key not in configs_map: + assert params == [] + else: + config, uses_default_config = configs_map[key] + assert config is not None + + param_groups.append( + { + 'params': params, + 'wd_mult': wd_mult, # For backwards compatibility. 
+ 'lr_mult': 1.0, # For backwards compatibility. + 'is_expert_parallel': is_expert_parallel, + 'is_decoupled_lr': False, # For backwards compatibility. + 'default_config': uses_default_config, + 'use_muon': True, + } + ) return param_groups diff --git a/megatron/plugin/hetero/parallel_context.py b/megatron/plugin/hetero/parallel_context.py index 4f0c921d553..26f35bb6805 100644 --- a/megatron/plugin/hetero/parallel_context.py +++ b/megatron/plugin/hetero/parallel_context.py @@ -150,6 +150,7 @@ def __init__(self, args): self._rank_infos = {} self._physical_rank_to_logical_rank = {} self._logical_rank_to_physical_rank = {} + self._enable_simulator = args.enable_simulator self.build_rank_mapping() def build_rank_mapping(self): @@ -159,8 +160,16 @@ def build_rank_mapping(self): all_rank_infos = [None] * world_size cur_rank_info = {'rank': rank, 'device_type': self._hetero_current_device_type} - torch.distributed.all_gather_object( - all_rank_infos, cur_rank_info) + #torch.distributed.all_gather_object( + # all_rank_infos, cur_rank_info) + if self._enable_simulator: + for index, value in enumerate(all_rank_infos): + corresponding_rank_info = {'rank': index, 'device_type': self._hetero_current_device_type} + all_rank_infos[index] = corresponding_rank_info + else: + torch.distributed.all_gather_object( + all_rank_infos, cur_rank_info) + physical_ranks = [] for info in all_rank_infos: self._rank_infos[info['rank']] = info @@ -334,10 +343,17 @@ def build_process_group( ranks = self._rank_mapper.to_physical_ranks(logical_ranks) group = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, pg_options=pg_options, group_desc=group_name) if gloo: - if create_gloo_process_groups: - group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo") + if self._args.enable_simulator: + if create_gloo_process_groups: + group_gloo = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, group_desc=group_name+"_gloo") + else: + group_gloo = None + else: - group_gloo = None + if create_gloo_process_groups: + group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo") + else: + group_gloo = None self._all_group_ranks[group_name].append(ranks) if self._rank in ranks: self._group_ranks[group_name] = ranks @@ -540,8 +556,8 @@ def get_ddp_config(self): return self._ddp_config def get_optimizer_config(self): - return (self._optimizer_config, self._optimizer_config_overrides) - + return self._optimizer_config + def logical_coords_to_physical_ranks(self, coords, is_expert=False): def _prefix_product(a: List[int], init=1) -> List[int]: r = [init] @@ -609,7 +625,6 @@ def __init__(self, args): self._tranformer_config = None self._ddp_config = None self._optimizer_config = None - self._optimizer_config_overrides = None self._dataset_config = None self.build_config() @@ -665,9 +680,19 @@ def build_all_process_meshes(self): "rank": rank, "process_mesh_idx": self._current_process_mesh_index, } - torch.distributed.all_gather_object( - all_rank_to_process_mesh, cur_rank_to_process_mesh - ) + #torch.distributed.all_gather_object( + # all_rank_to_process_mesh, cur_rank_to_process_mesh + #) + + if self._args.enable_simulator: + for index, value in enumerate(all_rank_to_process_mesh): + corresponding_mesh_info = {'rank': index, 'process_mesh_idx': self._current_process_mesh_index} + all_rank_to_process_mesh[index] = corresponding_mesh_info + else: + torch.distributed.all_gather_object( + 
all_rank_to_process_mesh, cur_rank_to_process_mesh
+            )
+
         for item in all_rank_to_process_mesh:
             self._rank_to_process_mesh[item["rank"]] = self._process_meshes[
                 item["process_mesh_idx"]
@@ -783,7 +808,12 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False):
                 aggregated_ranks = [rank for ranks in path for rank in ranks]
                 self._global_all_group_ranks[group_name].append(aggregated_ranks)
                 # NOTE: "use_local_synchronization=True" works well in torhch <= 2.5, but it causes hang in torch >= 2.6
-                group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name)
+                #group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name)
+                if self._args.enable_simulator:
+                    group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc=group_name)
+                else:
+                    group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name)
+
                 if self._rank in aggregated_ranks:
                     self._global_process_groups[group_name].append(group)
                     self._global_group_ranks[group_name].append(aggregated_ranks)
@@ -839,13 +869,24 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False):
             else:
                 embedding_ranks = ranks
                 position_embedding_ranks = ranks
-            group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd")
+            #group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd")
+            if self._args.enable_simulator:
+                group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc="embd")
+            else:
+                group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd")
+
+
             if self._rank in embedding_ranks and ("embd" not in self._global_group_ranks or embedding_ranks not in self._global_group_ranks["embd"]):
                 self._global_process_groups["embd"].append(group)
                 self._global_process_group_to_ranks[group] = embedding_ranks
                 self._global_group_ranks["embd"].append(embedding_ranks)
-            group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos")
+            #group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos")
+            if self._args.enable_simulator:
+                group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc="embd_pos")
+            else:
+                group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos")
+
             if self._rank in position_embedding_ranks:
                 self._global_process_groups["embd_pos"].append(group)
                 self._global_process_group_to_ranks[group] = position_embedding_ranks
@@ -1661,22 +1702,26 @@ def _build_ddp_config(args):
             kwargs[f.name] = getattr(args, f.name)
     kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
+    if args.enable_simulator:
+        args.check_for_nan_in_loss_and_grad = False
     kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
     kwargs['bucket_size'] = args.ddp_bucket_size
     kwargs['average_in_collective'] = args.ddp_average_in_collective
     ddp_config = DistributedDataParallelConfig(**kwargs)
     return ddp_config
 def _build_optimzer_config(args):
-    # Use specific 
optimizer config class based on optimizer type, matching Megatron-LM-FL behavior - from megatron.training.utils import get_megatron_optimizer_config - config, config_overrides = get_megatron_optimizer_config(args) - return config, config_overrides + from megatron.core.optimizer import OptimizerConfig + kwargs = {} + for f in dataclasses.fields(OptimizerConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + return OptimizerConfig(**kwargs) def _build_dataset_config(args): from megatron.core.datasets.gpt_dataset import GPTDatasetConfig from megatron.training import get_tokenizer from megatron.core.datasets.utils import get_blend_from_list - from megatron.training.datasets.sft_dataset_fs import SFTDatasetConfig + from flagscale.train.datasets.sft_dataset import SFTDatasetConfig if args.apply_sft_dataset_separated_loss_mask_if_existed: tokenizer = get_tokenizer() @@ -1729,7 +1774,7 @@ def _build_dataset_config(args): from megatron.training.arguments import core_transformer_config_from_args self._transformer_config = core_transformer_config_from_args(self._args) self._ddp_config = _build_ddp_config(self._args) - self._optimizer_config, self._optimizer_config_overrides = _build_optimzer_config(self._args) + self._optimizer_config = _build_optimzer_config(self._args) self._dataset_config = _build_dataset_config(self._args) def get_transformer_config(self): @@ -1739,7 +1784,7 @@ def get_ddp_config(self): return self._ddp_config def get_optimizer_config(self): - return (self._optimizer_config, self._optimizer_config_overrides) + return self._optimizer_config def get_dataset_config(self): return self._dataset_config
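
Usage note: the simulator path added by this patch is gated by two switches, the ENABLE_SIMULATOR environment variable (read in megatron/core/optimizer/__init__.py) and the --enable-simulator argument surfaced as args.enable_simulator (read throughout megatron/plugin/hetero/parallel_context.py). The snippet below is a minimal, hypothetical sketch of the gating pattern the patch applies, not code taken from the patch; the helper name simulate_all_gather_object, the SimpleNamespace args, and the "A800" device type are assumptions for illustration only.

    import os
    from types import SimpleNamespace

    import torch

    def simulate_all_gather_object(cur_obj, world_size, enable_simulator):
        # Mirrors the pattern used in this patch: in simulator mode, every
        # rank's object is synthesized locally (reusing the current process's
        # metadata and rewriting only the rank id) instead of being exchanged
        # via torch.distributed.all_gather_object, so a single process can
        # stand in for the whole world.
        gathered = [None] * world_size
        if enable_simulator or os.environ.get("ENABLE_SIMULATOR") == "1":
            for index in range(world_size):
                gathered[index] = {**cur_obj, "rank": index}
        else:
            torch.distributed.all_gather_object(gathered, cur_obj)
        return gathered

    # Hypothetical single-process invocation of the helper above.
    args = SimpleNamespace(enable_simulator=True)
    infos = simulate_all_gather_object(
        {"rank": 0, "device_type": "A800"}, world_size=8,
        enable_simulator=args.enable_simulator,
    )

This mirrors build_rank_mapping and build_all_process_meshes above, where the collective calls are bypassed and each simulated rank inherits the current process's device type and process-mesh index.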