Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 90 additions & 62 deletions megatron/core/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
from .muon import Muon
import os


try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
Expand Down Expand Up @@ -179,81 +181,107 @@ def _get_param_groups(
assert (config_for_param, uses_default_config) == configs_map[key]
else:
configs_map[key] = (config_for_param, uses_default_config)

# Distributed checkpoint requires all ranks to have the same param groups,
# so we need to align the param groups across ranks, otherwise we may have
# runtime error when loading the checkpoint or numerical error when resuming training.
params_key = list(params_map.keys())
gathered_params_key = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_params_key, params_key)
for keys in gathered_params_key:
for key in keys:
if key not in params_key:
params_key.append(key)

# For the Muon optimizer, gather and align the muon param-group keys across
# all ranks as well; otherwise the distributed checkpoint may fail to load
# (runtime error) or training may resume with numerical errors.
muon_params_key = list(muon_params_map.keys())
gathered_muon_params_key = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_muon_params_key, muon_params_key)
for keys in gathered_muon_params_key:
for key in keys:
if key not in muon_params_key:
muon_params_key.append(key)

param_groups = []
for key in params_key:
wd_mult, is_expert_parallel, is_vision_model_param, _ = key
params = params_map[key] if key in params_map else []
config, uses_default_config = None, True
if key not in configs_map:
assert params == []
else:
config, uses_default_config = configs_map[key]
assert config is not None
if os.environ.get("ENABLE_SIMULATOR") == "1":
param_groups = []
params_key = list(params_map.keys())
for key in params_key:
wd_mult, is_expert_parallel, is_vision_model_param, _ = key
params = params_map[key] if key in params_map else []
config, uses_default_config = None, True
if key not in configs_map:
assert params == []
else:
config, uses_default_config = configs_map[key]
assert config is not None

# TODO: Remove "backwards compatible" fields below eventually.
param_group = {
'params': params,
'wd_mult': wd_mult, # For backwards compatibility.
'lr_mult': 1.0, # For backwards compatibility.
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': False, # For backwards compatibility.
'default_config': uses_default_config,
'is_vision_model_param': is_vision_model_param,
}

# Stick relevant fields into param_group from config object.
if config is not None:
param_group['max_lr'] = config.lr if not is_vision_model_param else config.lr * config.vision_ration # NOTE(lizhiyu): change the ration here
param_group['min_lr'] = config.min_lr
# TODO: Add other relevant arguments (e.g., weight decay, optimizer)
# here as well.
param_groups.append(param_group)

for key in muon_params_key:
wd_mult, is_expert_parallel, _ = key
params = muon_params_map[key] if key in muon_params_map else []
config, uses_default_config = None, True
if key not in configs_map:
assert params == []
else:
config, uses_default_config = configs_map[key]
assert config is not None
param_group = {
'params': params,
'wd_mult': wd_mult, # For backwards compatibility.
'lr_mult': 1.0, # For backwards compatibility.
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': False, # For backwards compatibility.
'default_config': uses_default_config,
'is_vision_model_param': is_vision_model_param,
}
param_groups.append(param_group)

else:
params_key = list(params_map.keys())
gathered_params_key = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_params_key, params_key)
for keys in gathered_params_key:
for key in keys:
if key not in params_key:
params_key.append(key)

# For the Muon optimizer, gather and align the muon param-group keys across
# all ranks as well; otherwise the distributed checkpoint may fail to load
# (runtime error) or training may resume with numerical errors.
muon_params_key = list(muon_params_map.keys())
gathered_muon_params_key = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_muon_params_key, muon_params_key)
for keys in gathered_muon_params_key:
for key in keys:
if key not in muon_params_key:
muon_params_key.append(key)

param_groups = []
for key in params_key:
wd_mult, is_expert_parallel, is_vision_model_param, _ = key
params = params_map[key] if key in params_map else []
config, uses_default_config = None, True
if key not in configs_map:
assert params == []
else:
config, uses_default_config = configs_map[key]
assert config is not None

param_groups.append(
{
# TODO: Remove "backwards compatible" fields below eventually.
param_group = {
'params': params,
'wd_mult': wd_mult, # For backwards compatibility.
'lr_mult': 1.0, # For backwards compatibility.
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': False, # For backwards compatibility.
'default_config': uses_default_config,
'use_muon': True,
'is_vision_model_param': is_vision_model_param,
}
)

# Stick relevant fields into param_group from config object.
if config is not None:
param_group['max_lr'] = config.lr if not is_vision_model_param else config.lr * config.vision_ration # NOTE(lizhiyu): change the ration here
param_group['min_lr'] = config.min_lr
# TODO: Add other relevant arguments (e.g., weight decay, optimizer)
# here as well.
param_groups.append(param_group)

for key in muon_params_key:
wd_mult, is_expert_parallel, _ = key
params = muon_params_map[key] if key in muon_params_map else []
config, uses_default_config = None, True
if key not in configs_map:
assert params == []
else:
config, uses_default_config = configs_map[key]
assert config is not None

param_groups.append(
{
'params': params,
'wd_mult': wd_mult, # For backwards compatibility.
'lr_mult': 1.0, # For backwards compatibility.
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': False, # For backwards compatibility.
'default_config': uses_default_config,
'use_muon': True,
}
)

return param_groups

Expand Down
87 changes: 66 additions & 21 deletions megatron/plugin/hetero/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self, args):
self._rank_infos = {}
self._physical_rank_to_logical_rank = {}
self._logical_rank_to_physical_rank = {}
self._enable_simulator = args.enable_simulator
self.build_rank_mapping()

def build_rank_mapping(self):
Expand All @@ -159,8 +160,16 @@ def build_rank_mapping(self):
all_rank_infos = [None] * world_size
cur_rank_info = {'rank': rank,
'device_type': self._hetero_current_device_type}
torch.distributed.all_gather_object(
all_rank_infos, cur_rank_info)
#torch.distributed.all_gather_object(
# all_rank_infos, cur_rank_info)
if self._enable_simulator:
for index, value in enumerate(all_rank_infos):
corresponding_rank_info = {'rank': index, 'device_type': self._hetero_current_device_type}
all_rank_infos[index] = corresponding_rank_info
else:
torch.distributed.all_gather_object(
all_rank_infos, cur_rank_info)

physical_ranks = []
for info in all_rank_infos:
self._rank_infos[info['rank']] = info
Expand Down Expand Up @@ -334,10 +343,17 @@ def build_process_group(
ranks = self._rank_mapper.to_physical_ranks(logical_ranks)
group = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, pg_options=pg_options, group_desc=group_name)
if gloo:
if create_gloo_process_groups:
group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo")
if self._args.enable_simulator:
if create_gloo_process_groups:
group_gloo = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, group_desc=group_name+"_gloo")
else:
group_gloo = None

else:
group_gloo = None
if create_gloo_process_groups:
group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo")
else:
group_gloo = None
self._all_group_ranks[group_name].append(ranks)
if self._rank in ranks:
self._group_ranks[group_name] = ranks
Expand Down Expand Up @@ -540,8 +556,8 @@ def get_ddp_config(self):
return self._ddp_config

def get_optimizer_config(self):
return (self._optimizer_config, self._optimizer_config_overrides)
return self._optimizer_config

def logical_coords_to_physical_ranks(self, coords, is_expert=False):
def _prefix_product(a: List[int], init=1) -> List[int]:
r = [init]
Expand Down Expand Up @@ -609,7 +625,6 @@ def __init__(self, args):
self._tranformer_config = None
self._ddp_config = None
self._optimizer_config = None
self._optimizer_config_overrides = None
self._dataset_config = None

self.build_config()
Expand Down Expand Up @@ -665,9 +680,19 @@ def build_all_process_meshes(self):
"rank": rank,
"process_mesh_idx": self._current_process_mesh_index,
}
torch.distributed.all_gather_object(
all_rank_to_process_mesh, cur_rank_to_process_mesh
)
#torch.distributed.all_gather_object(
# all_rank_to_process_mesh, cur_rank_to_process_mesh
#)

if self._args.enable_simulator:
for index, value in enumerate(all_rank_to_process_mesh):
corresponding_mesh_info = {'rank': index, 'process_mesh_idx': self._current_process_mesh_index}
all_rank_to_process_mesh[index] = corresponding_mesh_info
else:
torch.distributed.all_gather_object(
all_rank_to_process_mesh, cur_rank_to_process_mesh
)

for item in all_rank_to_process_mesh:
self._rank_to_process_mesh[item["rank"]] = self._process_meshes[
item["process_mesh_idx"]
Expand Down Expand Up @@ -783,7 +808,12 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False):
aggregated_ranks = [rank for ranks in path for rank in ranks]
self._global_all_group_ranks[group_name].append(aggregated_ranks)
# NOTE: "use_local_synchronization=True" works well in torhch <= 2.5, but it causes hang in torch >= 2.6
group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name)
#group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name)
if self._args.enable_simulator:
group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc=group_name)
else:
group = create_group(aggregated_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc=group_name)

if self._rank in aggregated_ranks:
self._global_process_groups[group_name].append(group)
self._global_group_ranks[group_name].append(aggregated_ranks)
Expand Down Expand Up @@ -839,13 +869,24 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False):
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd")
#group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd")
if self._args.enable_simulator:
group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc="embd")
else:
group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd")


if self._rank in embedding_ranks and ("embd" not in self._global_group_ranks or embedding_ranks not in self._global_group_ranks["embd"]):
self._global_process_groups["embd"].append(group)
self._global_process_group_to_ranks[group] = embedding_ranks
self._global_group_ranks["embd"].append(embedding_ranks)

group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos")
#group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos")
if self._args.enable_simulator:
group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, backend=self._args.distributed_backend, group_desc="embd_pos")
else:
group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=False, group_desc="embd_pos")

if self._rank in position_embedding_ranks:
self._global_process_groups["embd_pos"].append(group)
self._global_process_group_to_ranks[group] = position_embedding_ranks
Expand Down Expand Up @@ -1661,22 +1702,26 @@ def _build_ddp_config(args):
kwargs[f.name] = getattr(args, f.name)
kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
if args.enable_simulator:
args.check_for_nan_in_loss_and_grad = False
kwargs['bucket_size'] = args.ddp_bucket_size
kwargs['average_in_collective'] = args.ddp_average_in_collective
ddp_config = DistributedDataParallelConfig(**kwargs)
return ddp_config

def _build_optimzer_config(args):
# Use specific optimizer config class based on optimizer type, matching Megatron-LM-FL behavior
from megatron.training.utils import get_megatron_optimizer_config
config, config_overrides = get_megatron_optimizer_config(args)
return config, config_overrides
from megatron.core.optimizer import OptimizerConfig
kwargs = {}
for f in dataclasses.fields(OptimizerConfig):
if hasattr(args, f.name):
kwargs[f.name] = getattr(args, f.name)
return OptimizerConfig(**kwargs)

def _build_dataset_config(args):
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.training import get_tokenizer
from megatron.core.datasets.utils import get_blend_from_list
from megatron.training.datasets.sft_dataset_fs import SFTDatasetConfig
from flagscale.train.datasets.sft_dataset import SFTDatasetConfig

if args.apply_sft_dataset_separated_loss_mask_if_existed:
tokenizer = get_tokenizer()
Expand Down Expand Up @@ -1729,7 +1774,7 @@ def _build_dataset_config(args):
from megatron.training.arguments import core_transformer_config_from_args
self._transformer_config = core_transformer_config_from_args(self._args)
self._ddp_config = _build_ddp_config(self._args)
self._optimizer_config, self._optimizer_config_overrides = _build_optimzer_config(self._args)
self._optimizer_config = _build_optimzer_config(self._args)
self._dataset_config = _build_dataset_config(self._args)

def get_transformer_config(self):
Expand All @@ -1739,7 +1784,7 @@ def get_ddp_config(self):
return self._ddp_config

def get_optimizer_config(self):
return (self._optimizer_config, self._optimizer_config_overrides)
return self._optimizer_config

def get_dataset_config(self):
return self._dataset_config