diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 367ba524a..da3b33c7b 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -25,7 +25,6 @@ convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, internlm1_mha_pre_load_convert, - internlm1_mha_save_convert, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger @@ -122,7 +121,7 @@ def __init__( ) # Compatible with the name of internlm1 Wqkv linear layer - self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert) + self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert) self.dropout1 = nn.Dropout(drop_rate) self.dropout2 = nn.Dropout(drop_rate) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index ed32ca03c..1b7092c07 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -22,7 +22,6 @@ convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, internlm1_mha_pre_load_convert, - internlm1_mha_save_convert, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger @@ -112,7 +111,7 @@ def __init__( ) # Compatible with the name of internlm1 Wqkv linear layer - self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert) + self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert) self.dropout1 = nn.Dropout(drop_rate) self.dropout2 = nn.Dropout(drop_rate) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 0dc9d3072..3ae6a5a10 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -649,12 +649,19 @@ def __init__( self.tp_dim = 1 else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.tp_dim = -1 self.complete_size = [out_features, in_features] setattr(self.weight, "offset", self.offset) setattr(self.weight, "complete_size", [out_features, in_features]) setattr(self.weight, "tp_dim", self.tp_dim) + if bias: + if self.tp_dim == 0: + setattr(self.bias, "tp_dim", 0) + else: + setattr(self.bias, "tp_dim", -1) + def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." 
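The tp_dim attribute stamped onto weights and biases above is the convention the checkpoint tooling in this PR relies on: -1 marks a parameter as replicated across tensor-parallel ranks, while a non-negative value names the dimension along which each rank holds a shard. A minimal sketch of how a conversion tool can consume it (merge_tp_shards is a hypothetical helper, not part of this diff):

import torch

def merge_tp_shards(shards, tp_dim):
    # tp_dim == -1: replicated parameter; every rank holds the full tensor,
    # so the shards must match and any one of them can be kept.
    if tp_dim == -1:
        assert all(torch.equal(shards[0], s) for s in shards[1:])
        return shards[0]
    # Otherwise the full tensor is the concatenation of the per-rank shards
    # along the recorded split dimension.
    return torch.cat(shards, dim=tp_dim)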
@@ -904,16 +911,24 @@ def __init__( # pylint: disable=W0231, W0233 self.weight = nn.Parameter( torch.empty(num_groups, in_features, local_multiple * multiple_of, device=device, dtype=dtype) ) + self.tp_dim = 2 + assert self.weight.shape[self.tp_dim] != out_features elif split_mode == "row": self.weight = nn.Parameter( torch.empty(num_groups, local_multiple * multiple_of, out_features, device=device, dtype=dtype) ) + self.tp_dim = 1 + assert self.weight.shape[self.tp_dim] != in_features elif split_mode == "weight": self.weight = nn.Parameter( torch.empty(local_multiple * multiple_of, out_features, device=device, dtype=dtype) ) + self.tp_dim = 0 else: # none self.weight = nn.Parameter(torch.empty(num_groups, in_features, out_features, device=device, dtype=dtype)) + self.tp_dim = -1 + + setattr(self.weight, "tp_dim", self.tp_dim) self.register_parameter("bias", None) torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 5c8a60b3b..dd4a3d2ae 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -160,7 +160,9 @@ def register_checkpoint_compatibility_hooks( # hoping that model developers will make good use of it when adapting. # Is this interface already meeting all reasonable requirements? self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True) - self._register_state_dict_hook(pre_save_hook) + if pre_save_hook is not None: + logger.warning("pre_save_hook may destroy universal_ckpt") + self._register_state_dict_hook(pre_save_hook) def forward(self, x, inference_params=None, **kwargs): if inference_params is None: @@ -471,7 +473,9 @@ def register_checkpoint_compatibility_hooks( # hoping that model developers will make good use of it when adapting. # Is this interface already meeting all reasonable requirements?
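# Note: internlm1_mha_pre_load_convert, still registered below, is what keeps
# old InternLM1 checkpoints loadable now that the save-side hook is optional.
# An illustrative sketch of such a pre-load hook -- an assumed shape, not the
# actual implementation in this repo -- renaming legacy "Wqkv" keys in place:
def wqkv_pre_load_convert_sketch(module, state_dict, prefix, *args):
    # Rewrite keys before nn.Module.load_state_dict consumes the state dict.
    for key in list(state_dict.keys()):
        if "Wqkv" in key:
            state_dict[key.replace("Wqkv", "wqkv")] = state_dict.pop(key)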
self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True) - self._register_state_dict_hook(pre_save_hook) + if pre_save_hook is not None: + logger.warning("pre_save_hook may destroy universal_ckpt") + self._register_state_dict_hook(pre_save_hook) def forward(self, x, inference_params=None, **kwargs): if inference_params is None: diff --git a/internlm/model/moe/base_layer.py b/internlm/model/moe/base_layer.py index 7811e056d..eafbf7612 100644 --- a/internlm/model/moe/base_layer.py +++ b/internlm/model/moe/base_layer.py @@ -22,7 +22,7 @@ def __init__( ) -> None: super().__init__() # for elastic expert paralle, experts may have multiple groups - expert_group_name = f"moe_ep_size_{ep_size}" + expert_group_name = "moe_ep_group" if expert_group_name not in gpc.expert_parallel_group_names: gpc.expert_parallel_group_names.append(expert_group_name) self.gate = gate diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 6620dda2e..3ca392542 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -149,11 +149,16 @@ def __init__( assert self._param_bcast_sync_handler is not None self._isp_communicator = isp_communicator - self.meta_for_zero = None + self.meta_for_zero = {"base_groups": {}} + self.meta_for_moe = {"base_groups": {}} + self.moe_group = [] # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): + if "moe" in param_group and param_group["moe"]: + self.moe_group.append(group_id) + group_params = param_group["params"] # set the dtype for each param group @@ -166,8 +171,6 @@ def __init__( self._zero_local_rank.append(gpc.get_local_rank(zero_mode)) self._zero_world_size.append(gpc.get_world_size(zero_mode)) - if gpc.config.ckpt.need_metadata and self.meta_for_zero is None: - self.meta_for_zero = [{} for _ in range(gpc.get_world_size(zero_mode))] # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name self._broadcast_parallel_mode.append(zero_mode) @@ -232,6 +235,12 @@ def __init__( # managed by this data parallel rank param_group["params"] = [fp32_flat_current_rank] + base_groups = self.optim.state_dict()["param_groups"][group_id]["params"] + if group_id in self.moe_group: + self.meta_for_moe["base_groups"][group_id] = base_groups + else: + self.meta_for_zero["base_groups"][group_id] = base_groups + # set reduction state for param in self._fp16_param_groups[group_id]: self._param_store.set_param_reduction_state(param, False) @@ -285,20 +294,39 @@ def _partition_param_list(self, group_id, param_group): numel_per_rank[rank_to_go] += param.numel() if gpc.config.ckpt.need_metadata: - if group_id not in self.meta_for_zero[rank_to_go]: - self.meta_for_zero[rank_to_go][group_id] = {} - - from internlm.train.pipeline import map_fqn_local_to_global - - global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn - self.meta_for_zero[rank_to_go][group_id][global_fqn] = { - "tp_dim": getattr(param, "tp_dim", -1), - "pp": gpc.get_local_rank(ParallelMode.PIPELINE), - "zero1": rank_to_go, - "fqn": param.fqn, - "shape": param.shape, - "group_id": group_id, - } + if rank_to_go == self.zero_local_rank[group_id]: + + from internlm.train.pipeline import map_fqn_local_to_global + + global_fqn = ( + map_fqn_local_to_global[param.fqn] if
param.fqn in map_fqn_local_to_global else param.fqn + ) + if group_id in self.moe_group: + if group_id not in self.meta_for_moe: + self.meta_for_moe[group_id] = {} + tp_mode = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR + self.meta_for_moe[group_id][global_fqn] = { + "tp_dim": getattr(param, "tp_dim", -1), + "tp": gpc.get_local_rank(tp_mode), + "pp": gpc.get_local_rank(ParallelMode.PIPELINE), + "ep": gpc.get_local_rank(ParallelMode.EXPERT), + "edp": gpc.get_local_rank(ParallelMode.EXPERT_DATA), + "zero1": rank_to_go, + "fqn": param.fqn, + "shape": param.shape, + "group_id": group_id, + } + else: + if group_id not in self.meta_for_zero: + self.meta_for_zero[group_id] = {} + self.meta_for_zero[group_id][global_fqn] = { + "tp_dim": getattr(param, "tp_dim", -1), + "pp": gpc.get_local_rank(ParallelMode.PIPELINE), + "zero1": rank_to_go, + "fqn": param.fqn, + "shape": param.shape, + "group_id": group_id, + } # check whether any rank is not assigned to parameters. for rank, params in enumerate(params_per_rank): diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 0586cafc7..b5149715c 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -6,6 +6,7 @@ import itertools import math import os +import re import time from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union @@ -165,152 +166,149 @@ def set_param_unique_tracking_name(model): if isinstance(children, nn.ModuleList): for idx, block in enumerate(children): for name, child in block.named_modules(): - if name == "": + if name == "" or not hasattr(child, "weight"): continue - full_name = f"{chunk_id}.{idx}.{name}" - name_parts = f"{full_name}.weight".split(".", 2) - # global_id for pipeline parallel case - global_id = model.first_layer + idx - local_fqn = f"{children_name}." + ".".join(name_parts[1:]) - global_fqn = f"{children_name}.{global_id}." 
+ ".".join(name_parts[2:]) + full_name = f"{chunk_id}.{idx}.{name}" # noqa: F841 # pylint: disable=W0612 + if "wrapped_experts" in name or getattr(child.weight, "is_expert", False): + match = re.search(r"wrapped_experts\.(\d+)", name) + ep_idx = int(match.group(1)) + before_idx = name[: match.start(1)] + after_idx = name[match.end(1) :] - if isinstance(child, (ParallelLinearWithCommExt)): - setattr( - child.weight, - "tracking_name", - f"{full_name}.weight", - ) - if child.bias is not None: - setattr( - child.bias, - "tracking_name", - f"{full_name}.bias", - ) + ep_rank = gpc.get_local_rank(ParallelMode.EXPERT) + ep_size = gpc.get_world_size(ParallelMode.EXPERT) + num_experts = gpc.config.model.num_experts - setattr( - child.weight, - "fqn", - f"{local_fqn}", - ) - if child.bias is not None: - setattr( - child.bias, - "fqn", - f"{local_fqn}", - ) - - assert hasattr(child, "offset"), f"{child}" - map_fqn_local_to_global[local_fqn] = global_fqn - map_fqn_global_to_local[global_fqn] = local_fqn + global_ep_idx = int(ep_idx + ep_rank * (num_experts // ep_size)) + global_name = before_idx + str(global_ep_idx) + after_idx - assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists" - map_layer_attr[global_fqn] = { - "offset": getattr(child, "offset", [0] * len(child.weight.size())), - "complete_size": getattr(child, "complete_size", list(child.weight.size())), - } + local_fqn = f"{children_name}.{idx}.{global_name}.weight" + global_fqn = f"{children_name}.{model.first_layer + idx}.{global_name}.weight" + else: + local_fqn = f"{children_name}.{idx}.{name}.weight" + global_fqn = f"{children_name}.{model.first_layer + idx}.{name}.weight" - elif isinstance(child, (RMSNorm)): - map_fqn_local_to_global[local_fqn] = global_fqn - map_fqn_global_to_local[global_fqn] = local_fqn + setattr( + child.weight, + "fqn", + f"{local_fqn}", + ) + + map_fqn_local_to_global[local_fqn] = global_fqn + map_fqn_global_to_local[global_fqn] = local_fqn + + assert global_fqn not in map_layer_attr, f"{global_fqn} exists" + map_layer_attr[global_fqn] = { + "offset": getattr(child, "offset", [0] * len(child.weight.size())), + "complete_size": getattr(child, "complete_size", list(child.weight.size())), + } + + if hasattr(child, "bias") and child.bias is not None: + local_fqn = local_fqn.replace("weight", "bias") + global_fqn = global_fqn.replace("weight", "bias") setattr( - child.weight, + child.bias, "fqn", f"{local_fqn}", ) - map_layer_attr[global_fqn] = { - "offset": getattr(child, "offset", [0] * len(child.weight.size())), - "complete_size": getattr(child, "complete_size", list(child.weight.size())), - } + map_fqn_local_to_global[local_fqn] = global_fqn + map_fqn_global_to_local[global_fqn] = local_fqn else: - full_name = f"{chunk_id}.{children_name}" local_fqn = f"{children_name}.weight" assert getattr(children, "bias", None) is None - if isinstance(children, Embedding1D): - setattr( - children.weight, - "tracking_name", - f"{chunk_id}_embeddings.weight", - ) - assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists" - else: - setattr( - children.weight, - "tracking_name", - f"{full_name}.weight", - ) - assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists" - setattr( children.weight, "fqn", f"{local_fqn}", ) - if getattr(children, "bias", None) is not None: - if children.bias is not None: - setattr( - children.bias, - "fqn", - f"{local_fqn}", - ) - map_layer_attr[local_fqn] = { "offset": getattr(children, "offset", [0] * len(children.weight.size())), "complete_size": getattr(children, 
"complete_size", list(children.weight.size())), } -def generate_meta_data(optimizer): - if not gpc.config.ckpt.need_metadata: - return - - if gpc.get_world_size(ParallelMode.PIPELINE) > 1: - assert optimizer.meta_for_zero is not None - dst = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0] +def gather_meta(metadata, parallelmode, moe_no_tp=False): + if not moe_no_tp and gpc.get_world_size(parallelmode) > 1: + dst = gpc.get_ranks_in_group(parallelmode)[0] if gpc.get_global_rank() == dst: - output = [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + output = [None for _ in range(gpc.get_world_size(parallelmode))] else: output = None - dist.gather_object(optimizer.meta_for_zero, output, dst=dst, group=gpc.get_group(ParallelMode.PIPELINE)) - pp_gather_output = output - + dist.gather_object(metadata, output, dst=dst, group=gpc.get_group(parallelmode)) + res = output else: - pp_gather_output = [optimizer.meta_for_zero] + res = [metadata] + + return res + + +def generate_meta_data(optimizer): + if not gpc.config.ckpt.need_metadata: + return None + # dense_meta: [tp][pp][zero1] tp_parallel = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR - if gpc.get_world_size(tp_parallel) > 1: - dst = gpc.get_ranks_in_group(tp_parallel)[0] - if gpc.get_global_rank() == dst: - output = [None for _ in range(gpc.get_world_size(tp_parallel))] - else: - output = None + dense_meta = gather_meta(optimizer.meta_for_zero, ParallelMode.ZERO1) + dense_meta = gather_meta(dense_meta, ParallelMode.PIPELINE) + dense_meta = gather_meta(dense_meta, tp_parallel) + + # moe_meta: [pp][ep][edp][ewp] + moe_meta = None + if len(optimizer.moe_group) > 0: + tp_mode = "wp" if is_using_isp() else "tp" + rank_map = { + tp_mode: gpc.get_local_rank(ParallelMode.WEIGHT) + if is_using_isp() + else gpc.get_local_rank(ParallelMode.TENSOR), + "zero1": gpc.get_local_rank(ParallelMode.ZERO1), + } + moe_meta = { + "rank_map": rank_map, + "metaData": optimizer.meta_for_moe, + } - dist.gather_object(pp_gather_output, output, dst=dst, group=gpc.get_group(tp_parallel)) - final_output = output - else: - final_output = [pp_gather_output] + tp_parallel = ParallelMode.EXPERT_WEIGHT if is_using_isp() else ParallelMode.TENSOR + moe_no_tp = not is_using_isp() and gpc.config.parallel.expert.no_tp + moe_meta = gather_meta(moe_meta, tp_parallel, moe_no_tp=moe_no_tp) + moe_meta = gather_meta(moe_meta, ParallelMode.EXPERT_DATA) + moe_meta = gather_meta(moe_meta, ParallelMode.EXPERT) + moe_meta = gather_meta(moe_meta, ParallelMode.PIPELINE) if gpc.get_global_rank() == 0: - assert len(final_output) == gpc.get_world_size(tp_parallel) - assert len(final_output[0]) == gpc.get_world_size(ParallelMode.PIPELINE) - assert len(final_output[0][0]) == gpc.get_world_size(ParallelMode.ZERO1) tp_mode = "wp_size" if is_using_isp() else "tp_size" + if is_using_isp(): + ewp_size = gpc.get_world_size(ParallelMode.EXPERT_WEIGHT) + else: + if gpc.config.parallel.expert.no_tp: + ewp_size = 1 + else: + ewp_size = gpc.get_world_size(ParallelMode.TENSOR) + final_meta = { "parallel_setting": { tp_mode: gpc.get_world_size(tp_parallel), "pp_size": gpc.get_world_size(ParallelMode.PIPELINE), "zero1_size": gpc.get_world_size(ParallelMode.ZERO1), + "ep_size": gpc.get_world_size(ParallelMode.EXPERT), + "edp_size": gpc.get_world_size(ParallelMode.EXPERT_DATA), + "ewp_size": ewp_size, + "num_layers": gpc.config.model.num_layers, + "num_experts": gpc.config.model.num_experts if hasattr(gpc.config.model, "num_experts") else -1, }, - "metaData": final_output, + 
"metaData": dense_meta, + "moe_meta": moe_meta, + "moe_group": optimizer.moe_group, } if gpc.config.ckpt.generate_meta_data.enable: - save_path = os.path.join(gpc.config.ckpt.generate_meta_data.path, "metadata.pt") + file_path = gpc.config.ckpt.generate_meta_data.path + os.makedirs(file_path, exist_ok=True) + save_path = os.path.join(file_path, "metadata.pt") torch.save(final_meta, save_path) - logger.info(f"Successfully generate metadata.pt in {gpc.config.ckpt.generate_meta_data.path}") - + logger.info(f"Successfully generate metadata.pt in {save_path}") return final_meta return None @@ -764,6 +762,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato zero_cfg=zero_cfg, ) + if not isinstance(optimizer, HybridZeroOptimizer): + gpc.config.ckpt.need_metadata = False + assert not gpc.config.ckpt.generate_meta_data.enable, "Only support generate_meta_data with HybridZeroOptimizer" + beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler) diff --git a/tools/convert_ckpt_parallel.py b/tools/convert_ckpt_parallel.py index 78a30cd3a..5757e8458 100644 --- a/tools/convert_ckpt_parallel.py +++ b/tools/convert_ckpt_parallel.py @@ -2,14 +2,15 @@ Usage: python tools/convert_ckpt_parallel.py \ \ - (optional) [--origin_meta_path ] [--target_meta_path ] \ - (optional) [--copy_file ] [--convert_optimizer ] + --origin_meta_path --target_meta_path \ + --copy_file --convert_optimizer When meta_path is not specified, it will automatically search and load meta in the ckpt path. Default to convert optimizer state and copy files. Example: srun -p llm_s python tools/convert_ckpt_parallel.py \ - /llm_ckpt/100 /target_ckpt/converted + /llm_ckpt/100 /target_ckpt/converted \ + --target_meta_path /target_ckpt/metadata.pt """ import argparse import os @@ -93,14 +94,42 @@ def unflatten_tensor(flat_tensor, states): unflat_tensors.append(tensor) start += size + length = flat_tensor.shape[0] + assert start == length, f"{start} should equal to {length}" + return unflat_tensors -def preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode): +def preprocess_optimizer_state( + old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode, old_base_groups, moe_group, **kwargs +): # preprocess optimizer_state to unflatten format processed_ckpt_states = [ [[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size) ] + + if len(moe_group) > 0: + moe_meta = [[[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size)] + moe_base_groups = [ + [[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size) + ] + old_moe_base_groups = kwargs["old_moe_base_groups"] + for pp_rank in range(old_pp_size): + for ep_rank in range(kwargs["old_ep_size"]): + for edp_rank in range(kwargs["old_edp_size"]): + for ewp_rank in range(kwargs["old_ewp_size"]): + rank_map = old_meta["moe_meta"][pp_rank][ep_rank][edp_rank][ewp_rank]["rank_map"] + tp_rank = rank_map[kwargs["tp_mode"]] + zero1_rank = rank_map["zero1"] + assert moe_meta[tp_rank][pp_rank][zero1_rank] == {} + moe_meta[tp_rank][pp_rank][zero1_rank] = old_meta["moe_meta"][pp_rank][ep_rank][edp_rank][ + ewp_rank + ]["metaData"] + for group_id in moe_meta[tp_rank][pp_rank][zero1_rank]: + moe_base_groups[tp_rank][pp_rank][zero1_rank][group_id] = old_moe_base_groups[pp_rank][ + ep_rank + 
][edp_rank][ewp_rank][group_id] + for old_tp_rank in range(old_tp_size): for old_pp_rank in range(old_pp_size): for old_zero1_rank in range(old_zero1_size): @@ -111,16 +140,34 @@ def preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_met base_optim_states = ckpt_states["base_optim_states"]["state"] flat_fp32_weights = ckpt_states["flat_fp32_weights"] processed_state = ckpt_states - for group_id in list(base_optim_states.keys()): - exp_avg = base_optim_states[group_id]["exp_avg"] - exp_avg_sq = base_optim_states[group_id]["exp_avg_sq"] + + for group_id in list(flat_fp32_weights.keys()): + if group_id in moe_group: + metaData = moe_meta[old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + base_group_id = moe_base_groups[old_tp_rank][old_pp_rank][old_zero1_rank][group_id][0] + assert ( + len(moe_base_groups[old_tp_rank][old_pp_rank][old_zero1_rank][group_id]) == 1 + ), "unsupported" + else: + metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + base_group_id = old_base_groups[old_tp_rank][old_pp_rank][old_zero1_rank][group_id][0] + assert ( + len(old_base_groups[old_tp_rank][old_pp_rank][old_zero1_rank][group_id]) == 1 + ), "unsupported" + + exp_avg = base_optim_states[base_group_id]["exp_avg"] + exp_avg_sq = base_optim_states[base_group_id]["exp_avg_sq"] flat_tensor = flat_fp32_weights[group_id] - metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] unflat_exp_avg = unflatten_tensor(exp_avg, metaData) unflat_exp_avg_sq = unflatten_tensor(exp_avg_sq, metaData) unflat_tensor = unflatten_tensor(flat_tensor, metaData) + if group_id != base_group_id: + processed_state["base_optim_states"]["state"][group_id] = processed_state["base_optim_states"][ + "state" + ].pop(base_group_id) + processed_state["base_optim_states"]["state"][group_id]["exp_avg"] = unflat_exp_avg processed_state["base_optim_states"]["state"][group_id]["exp_avg_sq"] = unflat_exp_avg_sq processed_state["flat_fp32_weights"][group_id] = unflat_tensor @@ -158,6 +205,63 @@ def sort_optimizer_state(target_dict, meta_fqns): return sorted_exp_avg, sorted_exp_avg_sq, sorted_fp32_weights +def model_tp_merge( + old_pp_rank, + new_states, + old_tp_size, + new_tp_size, + old_tp_mode, + ratio, + old_meta_data, + new_meta_data, + old_map_local_to_global, + new_meta, + folder, +): + candidate_states = [defaultdict(list) for _ in range(new_tp_size)] + for old_tp_rank in range(old_tp_size): + ckpt_states = torch.load( + os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" + ) + for fqn, tensor in ckpt_states.items(): + new_tp_rank = old_tp_rank // ratio + candidate_states[new_tp_rank][fqn].append(tensor) + + for new_tp_rank, states in enumerate(candidate_states): + for fqn, tensor_list in states.items(): + global_fqn = old_map_local_to_global[old_pp_rank][fqn] + tp_dim = new_meta_data[new_tp_rank][global_fqn]["tp_dim"] + assert tp_dim == old_meta_data[0][global_fqn]["tp_dim"], ( + f"{global_fqn} tp_dim in old and new meta are not equal: " + f"new={tp_dim}, old={old_meta_data[0][global_fqn]['tp_dim']}" + ) + + new_pp_rank = new_meta_data[new_tp_rank][global_fqn]["pp"] + new_zero1_rank = new_meta_data[new_tp_rank][global_fqn]["zero1"] + new_fqn = new_meta_data[new_tp_rank][global_fqn]["fqn"] + group_id = new_meta_data[new_tp_rank][global_fqn]["group_id"] + assert "bias" not in global_fqn + + if old_tp_size == new_tp_size or tp_dim == -1: + if old_tp_size == new_tp_size: + assert len(tensor_list) == 1 + else: + if "bias" 
not in global_fqn: + assert torch.equal(tensor_list[0], tensor_list[1]), ( + f"{global_fqn} should not be split by tp, " + f"but the tensors in different checkpoints are not equal. {fqn}" + ) + new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor_list[0].detach().clone() + else: + new_states[new_tp_rank][new_pp_rank][new_fqn] = torch.concat(tensor_list, dim=tp_dim).detach().clone() + + splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape + meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]["shape"] + assert ( + splited_shape == meta_shape + ), f"{new_fqn}: split shape {splited_shape} is not equal to metaData {meta_shape}" + + def model_tp_split( split_maps, old_pp_rank, @@ -176,21 +280,29 @@ def model_tp_split( os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" ) for fqn, tensor in ckpt_states.items(): - assert len(tensor.size()) < 3, "Only support 2D or 1D tensors." global_fqn = old_map_local_to_global[old_pp_rank][fqn] - tp_dim = old_meta_data[global_fqn]["tp_dim"] - assert tp_dim == new_meta_data[global_fqn]["tp_dim"], ( + tp_dim = old_meta_data[old_tp_rank][global_fqn]["tp_dim"] + assert tp_dim == new_meta_data[0][global_fqn]["tp_dim"], ( f"{global_fqn} tp_dim in old and new meta are not equal: " - f"old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" + f"old={tp_dim}, new={new_meta_data[0][global_fqn]['tp_dim']}" ) - new_pp_rank = new_meta_data[global_fqn]["pp"] - new_zero1_rank = new_meta_data[global_fqn]["zero1"] - new_fqn = new_meta_data[global_fqn]["fqn"] - group_id = new_meta_data[global_fqn]["group_id"] + if tp_dim != -1: + split_size = tensor.size()[tp_dim] // ratio + new_tp_splits = torch.split(tensor, split_size, dim=tp_dim) + + for i, new_tp_rank in enumerate(split_maps[old_tp_rank]): + # bias is not split and only exists on rank 0 for row tp + if tp_dim == -1: + if "bias" in global_fqn and new_tp_rank > 0: + break + + new_pp_rank = new_meta_data[new_tp_rank][global_fqn]["pp"] + new_zero1_rank = new_meta_data[new_tp_rank][global_fqn]["zero1"] + new_fqn = new_meta_data[new_tp_rank][global_fqn]["fqn"] + group_id = new_meta_data[new_tp_rank][global_fqn]["group_id"] - if tp_dim == -1: - for _, new_tp_rank in enumerate(split_maps[old_tp_rank]): + if tp_dim == -1: new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor.detach().clone() splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn][ "shape" ] assert ( splited_shape == meta_shape ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" @@ -199,10 +311,7 @@ - else: - split_size = tensor.size()[tp_dim] // ratio - new_tp_splits = torch.split(tensor, split_size, dim=tp_dim) - for i, new_tp_rank in enumerate(split_maps[old_tp_rank]): + else: new_states[new_tp_rank][new_pp_rank][new_fqn] = new_tp_splits[i].detach().clone() splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn][ "shape" ] assert ( splited_shape == meta_shape ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" @@ -213,6 +322,75 @@ +def optimizer_tp_merge( + new_tp_size, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, + moe_group, +): + candidate_exp_avg = [defaultdict(list) for _ in range(new_tp_size)] + 
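# The per-parameter exp_avg/exp_avg_sq/fp32 lists consumed in this function
# come from unflatten_tensor, whose tail (with the new length check) appears
# earlier in this diff. A self-contained sketch of that unflattening, assuming
# each metadata entry records the parameter "shape" as shown above:
def unflatten_tensor_sketch(flat_tensor, states):
    # Slice the flat zero1 buffer back into per-parameter tensors, in the
    # order of the metadata dict (the same order used to flatten it).
    unflat_tensors, start = [], 0
    for meta in states.values():
        size = torch.Size(meta["shape"]).numel()
        unflat_tensors.append(flat_tensor[start : start + size].view(meta["shape"]))
        start += size
    assert start == flat_tensor.shape[0], f"{start} should equal {flat_tensor.shape[0]}"
    return unflat_tensors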
candidate_exp_avg_sq = [defaultdict(list) for _ in range(new_tp_size)] + candidate_fp32_weights = [defaultdict(list) for _ in range(new_tp_size)] + for old_tp_rank in range(old_tp_size): + ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] + for group_id in ckpt_states["flat_fp32_weights"].keys(): + if group_id in moe_group: + continue + old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] + exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] + fp32_weights_list = ckpt_states["flat_fp32_weights"][group_id] + new_tp_rank = old_tp_rank // ratio + for i, global_fqn in enumerate(list(old_metaData.keys())): + assert group_id == new_meta_data[0][global_fqn]["group_id"] + candidate_exp_avg[new_tp_rank][global_fqn].append(exp_avg_list[i]) + candidate_exp_avg_sq[new_tp_rank][global_fqn].append(exp_avg_sq_list[i]) + candidate_fp32_weights[new_tp_rank][global_fqn].append(fp32_weights_list[i]) + + for new_tp_rank in range(new_tp_size): + for global_fqn in candidate_fp32_weights[new_tp_rank].keys(): + splited_exp_avg = candidate_exp_avg[new_tp_rank][global_fqn] + splited_exp_avg_sq = candidate_exp_avg_sq[new_tp_rank][global_fqn] + splited_fp32_weights = candidate_fp32_weights[new_tp_rank][global_fqn] + + tp_dim = new_meta_data[new_tp_rank][global_fqn]["tp_dim"] + new_pp_rank = new_meta_data[new_tp_rank][global_fqn]["pp"] + new_zero1_rank = new_meta_data[new_tp_rank][global_fqn]["zero1"] + group_id = new_meta_data[new_tp_rank][global_fqn]["group_id"] + if group_id not in new_states[new_tp_rank][new_pp_rank][new_zero1_rank]: + new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] = {} + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + assert global_fqn not in target_new_states, f"repeated global_fqn {global_fqn}" + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + + if old_tp_size == new_tp_size or tp_dim == -1: + if old_tp_size == new_tp_size: + assert len(splited_fp32_weights) == 1 + else: + if "bias" not in global_fqn: + assert torch.equal(splited_fp32_weights[0], splited_fp32_weights[1]), ( + f"{global_fqn} should not be splited by tp," + "but the tensors in different checkpoints are not equal." 
+ ) + target_new_states[global_fqn]["exp_avg"] = splited_exp_avg[0].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = splited_exp_avg_sq[0].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = splited_fp32_weights[0].detach().clone() + else: + target_new_states[global_fqn]["exp_avg"] = torch.concat(splited_exp_avg, dim=tp_dim).detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = ( + torch.concat(splited_exp_avg_sq, dim=tp_dim).detach().clone() + ) + target_new_states[global_fqn]["fp32_weights"] = ( + torch.concat(splited_fp32_weights, dim=tp_dim).detach().clone() + ) + + def optimizer_tp_split( split_maps, old_tp_size, @@ -223,10 +401,13 @@ def optimizer_tp_split( processed_ckpt_states, new_states, ratio, + moe_group, ): for old_tp_rank in range(old_tp_size): ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] for group_id in ckpt_states["flat_fp32_weights"].keys(): + if group_id in moe_group: + continue old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] @@ -234,134 +415,203 @@ def optimizer_tp_split( for i, global_fqn in enumerate(list(old_metaData.keys())): tp_dim = old_metaData[global_fqn]["tp_dim"] - new_pp_rank = new_meta_data[global_fqn]["pp"] - new_zero1_rank = new_meta_data[global_fqn]["zero1"] - if tp_dim == -1: - for _, new_tp_rank in enumerate(split_maps[old_tp_rank]): - target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - if global_fqn not in target_new_states: - target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} - target_new_states[global_fqn]["exp_avg"] = exp_avg_list[i].detach().clone() - target_new_states[global_fqn]["exp_avg_sq"] = exp_avg_sq_list[i].detach().clone() - target_new_states[global_fqn]["fp32_weights"] = fp32_weights_list[i].detach().clone() - else: + if tp_dim != -1: split_size = old_metaData[global_fqn]["shape"][tp_dim] // ratio new_exp_avg_splits = torch.split(exp_avg_list[i], split_size, dim=tp_dim) new_exp_avg_sq_splits = torch.split(exp_avg_sq_list[i], split_size, dim=tp_dim) new_fp32_weights_splits = torch.split(fp32_weights_list[i], split_size, dim=tp_dim) - for j, new_tp_rank in enumerate(split_maps[old_tp_rank]): + + for j, new_tp_rank in enumerate(split_maps[old_tp_rank]): + # bias is not splitted and only exists in 0 rank for row tp + if tp_dim == -1: + if "bias" in global_fqn and new_tp_rank > 0: + break + + new_pp_rank = new_meta_data[new_tp_rank][global_fqn]["pp"] + new_zero1_rank = new_meta_data[new_tp_rank][global_fqn]["zero1"] + + if group_id not in new_states[new_tp_rank][new_pp_rank][new_zero1_rank]: + new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] = {} + + if tp_dim == -1: target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - if global_fqn not in target_new_states: - target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + assert global_fqn not in target_new_states, f"repeated global_fqn {global_fqn}" + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + target_new_states[global_fqn]["exp_avg"] = exp_avg_list[i].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = exp_avg_sq_list[i].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = 
fp32_weights_list[i].detach().clone() + else: + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + assert global_fqn not in target_new_states, f"repeated global_fqn {global_fqn}" + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} target_new_states[global_fqn]["exp_avg"] = new_exp_avg_splits[j].detach().clone() target_new_states[global_fqn]["exp_avg_sq"] = new_exp_avg_sq_splits[j].detach().clone() target_new_states[global_fqn]["fp32_weights"] = new_fp32_weights_splits[j].detach().clone() -def model_tp_merge( +def moe_tp_merge( + folder, + layer_id, + ep_id, + old_ewp_size, + new_ewp_size, old_pp_rank, - new_states, - old_tp_size, - new_tp_size, old_tp_mode, ratio, - old_meta_data, - new_meta_data, - old_map_local_to_global, - new_meta, - folder, + old_moe_map_local_to_global, + new_states, + old_moe_meta_data, + new_moe_meta_data, ): - candidate_states = [defaultdict(list) for _ in range(new_tp_size)] - for old_tp_rank in range(old_tp_size): + candidate_states = [defaultdict(list) for _ in range(new_ewp_size)] + for old_ewp_rank in range(old_ewp_size): ckpt_states = torch.load( - os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" + os.path.join(folder, f"model_moe_layer{layer_id}_expert{ep_id}_{old_tp_mode}{old_ewp_rank}.pt"), + map_location="cpu", ) for fqn, tensor in ckpt_states.items(): - assert len(tensor.size()) < 3, "Only support 2D or 1D tensors." - new_tp_rank = old_tp_rank // ratio + new_tp_rank = old_ewp_rank // ratio candidate_states[new_tp_rank][fqn].append(tensor) for new_tp_rank, states in enumerate(candidate_states): for fqn, tensor_list in states.items(): - global_fqn = old_map_local_to_global[old_pp_rank][fqn] - tp_dim = old_meta_data[global_fqn]["tp_dim"] - assert tp_dim == new_meta_data[global_fqn]["tp_dim"], ( + global_fqn = old_moe_map_local_to_global[old_pp_rank][fqn] + tp_dim = new_moe_meta_data[global_fqn]["tp_dim"] + assert tp_dim == old_moe_meta_data[global_fqn]["tp_dim"], ( f"{global_fqn} tp_dim in old and new meta are not equal: " - f"old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" + f"new={tp_dim}, old={old_moe_meta_data[global_fqn]['tp_dim']}" ) - new_pp_rank = new_meta_data[global_fqn]["pp"] - new_zero1_rank = new_meta_data[global_fqn]["zero1"] - new_fqn = new_meta_data[global_fqn]["fqn"] - group_id = new_meta_data[global_fqn]["group_id"] + new_fqn = new_moe_meta_data[global_fqn]["fqn"] + assert tp_dim != -1 - if tp_dim == -1: - assert torch.equal( - tensor_list[0], tensor_list[1] - ), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." 
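# Worked example of the global expert index behind the expert checkpoint file
# names handled here (formula from set_param_unique_tracking_name earlier in
# this diff):
#     global_ep_idx = ep_idx + ep_rank * (num_experts // ep_size)
# With num_experts=8 and ep_size=4, each expert-parallel rank owns 8 // 4 = 2
# local experts, so local expert 1 on ep_rank 2 is global expert 1 + 2 * 2 = 5
# and is saved as model_moe_layer{L}_expert5_{tp_mode}{ewp_rank}.pt.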
- new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor_list[0].detach().clone() + if old_ewp_size != new_ewp_size: + new_states[new_tp_rank][new_fqn] = torch.concat(tensor_list, dim=tp_dim).detach().clone() else: - new_states[new_tp_rank][new_pp_rank][new_fqn] = torch.concat(tensor_list, dim=tp_dim).detach().clone() + assert len(tensor_list) == 1 + new_states[new_tp_rank][new_fqn] = tensor_list[0].detach().clone() - splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape - meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]["shape"] + splited_shape = new_states[new_tp_rank][new_fqn].shape + meta_shape = new_moe_meta_data[global_fqn]["shape"] assert ( splited_shape == meta_shape ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" -def optimizer_tp_merge( - new_tp_size, - old_tp_size, +def moe_tp_split( + folder, + layer_id, + ep_id, + old_ewp_size, old_pp_rank, - old_zero1_rank, + old_tp_mode, + ratio, + split_maps, + old_moe_map_local_to_global, + new_states, + old_moe_meta_data, + new_moe_meta_data, +): + for old_ewp_rank in range(old_ewp_size): + ckpt_states = torch.load( + os.path.join(folder, f"model_moe_layer{layer_id}_expert{ep_id}_{old_tp_mode}{old_ewp_rank}.pt"), + map_location="cpu", + ) + for fqn, tensor in ckpt_states.items(): + global_fqn = old_moe_map_local_to_global[old_pp_rank][fqn] + tp_dim = old_moe_meta_data[global_fqn]["tp_dim"] + assert tp_dim == new_moe_meta_data[global_fqn]["tp_dim"], ( + f"{global_fqn} tp_dim in old and new meta are not equal: " + f"old={tp_dim}, new={new_moe_meta_data[global_fqn]['tp_dim']}" + ) + + split_size = tensor.size()[tp_dim] // ratio + new_tp_splits = torch.split(tensor, split_size, dim=tp_dim) + for i, new_tp_rank in enumerate(split_maps[old_ewp_rank]): + new_fqn = new_moe_meta_data[global_fqn]["fqn"] + new_states[new_tp_rank][new_fqn] = new_tp_splits[i].detach().clone() + splited_shape = new_states[new_tp_rank][new_fqn].shape + meta_shape = new_moe_meta_data[global_fqn]["shape"] + assert ( + splited_shape == meta_shape + ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + + +def optimizer_moe_tp_merge( + new_ewp_size, + old_ewp_size, + old_pp_rank, + old_ep_rank, + old_edp_rank, old_meta, - old_meta_data, - new_meta_data, + new_meta, + old_moe_meta_data, + new_moe_meta_data, + tp_mode, processed_ckpt_states, new_states, ratio, + moe_group, + moe_rank_map, ): - candidate_exp_avg = [defaultdict(list) for _ in range(new_tp_size)] - candidate_exp_avg_sq = [defaultdict(list) for _ in range(new_tp_size)] - candidate_fp32_weights = [defaultdict(list) for _ in range(new_tp_size)] - for old_tp_rank in range(old_tp_size): + candidate_exp_avg = [defaultdict(list) for _ in range(new_ewp_size)] + candidate_exp_avg_sq = [defaultdict(list) for _ in range(new_ewp_size)] + candidate_fp32_weights = [defaultdict(list) for _ in range(new_ewp_size)] + for old_ewp_rank in range(old_ewp_size): + old_rank_map = old_meta["moe_meta"][old_pp_rank][old_ep_rank][old_edp_rank][old_ewp_rank]["rank_map"] + old_tp_rank = old_rank_map[tp_mode] + old_zero1_rank = old_rank_map["zero1"] ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] - for group_id in ckpt_states["flat_fp32_weights"].keys(): - old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + + for group_id in moe_group: + old_metaData = old_meta["moe_meta"][old_pp_rank][old_ep_rank][old_edp_rank][old_ewp_rank]["metaData"][ + 
group_id + ] exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] fp32_weights_list = ckpt_states["flat_fp32_weights"][group_id] - new_tp_rank = old_tp_rank // ratio + new_ewp_rank = old_ewp_rank // ratio for i, global_fqn in enumerate(list(old_metaData.keys())): - assert group_id == new_meta_data[global_fqn]["group_id"] - candidate_exp_avg[new_tp_rank][global_fqn].append(exp_avg_list[i]) - candidate_exp_avg_sq[new_tp_rank][global_fqn].append(exp_avg_sq_list[i]) - candidate_fp32_weights[new_tp_rank][global_fqn].append(fp32_weights_list[i]) + assert group_id == new_moe_meta_data[global_fqn]["group_id"] + candidate_exp_avg[new_ewp_rank][global_fqn].append(exp_avg_list[i]) + candidate_exp_avg_sq[new_ewp_rank][global_fqn].append(exp_avg_sq_list[i]) + candidate_fp32_weights[new_ewp_rank][global_fqn].append(fp32_weights_list[i]) + + for new_ewp_rank in range(new_ewp_size): + for global_fqn in candidate_fp32_weights[new_ewp_rank].keys(): + splited_exp_avg = candidate_exp_avg[new_ewp_rank][global_fqn] + splited_exp_avg_sq = candidate_exp_avg_sq[new_ewp_rank][global_fqn] + splited_fp32_weights = candidate_fp32_weights[new_ewp_rank][global_fqn] + + tp_dim = old_moe_meta_data[global_fqn]["tp_dim"] + assert tp_dim == new_moe_meta_data[global_fqn]["tp_dim"] + + new_pp_rank = new_moe_meta_data[global_fqn]["pp"] + new_ep_rank = new_moe_meta_data[global_fqn]["ep"] + new_edp_rank = new_moe_meta_data[global_fqn]["edp"] + group_id = new_moe_meta_data[global_fqn]["group_id"] + + new_rank_map = new_meta["moe_meta"][new_pp_rank][new_ep_rank][new_edp_rank][new_ewp_rank]["rank_map"] + new_tp_rank = new_rank_map[tp_mode] + new_zero1_rank = new_rank_map["zero1"] + key = f"{new_tp_rank}_{new_pp_rank}_{new_zero1_rank}" + if key not in moe_rank_map: + moe_rank_map[key] = {"ep_rank": new_ep_rank, "edp_rank": new_edp_rank, "ewp_rank": new_ewp_rank} + else: + assert list(moe_rank_map[key].items()) == list( + {"ep_rank": new_ep_rank, "edp_rank": new_edp_rank, "ewp_rank": new_ewp_rank}.items() + ) - for new_tp_rank in range(len(candidate_fp32_weights)): - for global_fqn in candidate_fp32_weights[new_tp_rank].keys(): - splited_exp_avg = candidate_exp_avg[new_tp_rank][global_fqn] - splited_exp_avg_sq = candidate_exp_avg_sq[new_tp_rank][global_fqn] - splited_fp32_weights = candidate_fp32_weights[new_tp_rank][global_fqn] + if group_id not in new_states[new_tp_rank][new_pp_rank][new_zero1_rank]: + new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] = {} - tp_dim = old_meta_data[global_fqn]["tp_dim"] - new_pp_rank = new_meta_data[global_fqn]["pp"] - new_zero1_rank = new_meta_data[global_fqn]["zero1"] - group_id = new_meta_data[global_fqn]["group_id"] target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - if global_fqn not in target_new_states: - target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + assert global_fqn not in target_new_states, f"repeated global_fqn {global_fqn}" + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} - if tp_dim == -1: - assert torch.equal( - splited_fp32_weights[0], splited_fp32_weights[1] - ), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." 
- target_new_states[global_fqn]["exp_avg"] = splited_exp_avg[0].detach().clone() - target_new_states[global_fqn]["exp_avg_sq"] = splited_exp_avg_sq[0].detach().clone() - target_new_states[global_fqn]["fp32_weights"] = splited_fp32_weights[0].detach().clone() - else: + assert tp_dim != -1 + if old_ewp_size != new_ewp_size: target_new_states[global_fqn]["exp_avg"] = torch.concat(splited_exp_avg, dim=tp_dim).detach().clone() target_new_states[global_fqn]["exp_avg_sq"] = ( torch.concat(splited_exp_avg_sq, dim=tp_dim).detach().clone() @@ -369,6 +619,143 @@ def optimizer_tp_merge( target_new_states[global_fqn]["fp32_weights"] = ( torch.concat(splited_fp32_weights, dim=tp_dim).detach().clone() ) + else: + assert len(splited_fp32_weights) == 1 + target_new_states[global_fqn]["exp_avg"] = splited_exp_avg[0].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = splited_exp_avg_sq[0].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = splited_fp32_weights[0].detach().clone() + + +def optimizer_moe_tp_split( + split_maps, + old_pp_rank, + old_ewp_size, + old_ep_rank, + old_edp_rank, + old_meta, + new_meta, + new_moe_meta_data, + tp_mode, + processed_ckpt_states, + new_states, + ratio, + moe_group, + moe_rank_map, +): + for old_ewp_rank in range(old_ewp_size): + old_rank_map = old_meta["moe_meta"][old_pp_rank][old_ep_rank][old_edp_rank][old_ewp_rank]["rank_map"] + old_tp_rank = old_rank_map[tp_mode] + old_zero1_rank = old_rank_map["zero1"] + ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] + for group_id in moe_group: + old_metaData = old_meta["moe_meta"][old_pp_rank][old_ep_rank][old_edp_rank][old_ewp_rank]["metaData"][ + group_id + ] + exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] + exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] + fp32_weights_list = ckpt_states["flat_fp32_weights"][group_id] + + for i, global_fqn in enumerate(list(old_metaData.keys())): + tp_dim = old_metaData[global_fqn]["tp_dim"] + new_pp_rank = new_moe_meta_data[global_fqn]["pp"] + new_ep_rank = new_moe_meta_data[global_fqn]["ep"] + new_edp_rank = new_moe_meta_data[global_fqn]["edp"] + + if tp_dim != -1: + split_size = old_metaData[global_fqn]["shape"][tp_dim] // ratio + new_exp_avg_splits = torch.split(exp_avg_list[i], split_size, dim=tp_dim) + new_exp_avg_sq_splits = torch.split(exp_avg_sq_list[i], split_size, dim=tp_dim) + new_fp32_weights_splits = torch.split(fp32_weights_list[i], split_size, dim=tp_dim) + + for j, new_ewp_rank in enumerate(split_maps[old_ewp_rank]): + new_rank_map = new_meta["moe_meta"][new_pp_rank][new_ep_rank][new_edp_rank][new_ewp_rank][ + "rank_map" + ] + new_tp_rank = new_rank_map[tp_mode] + new_zero1_rank = new_rank_map["zero1"] + key = f"{new_tp_rank}_{new_pp_rank}_{new_zero1_rank}" + if key in moe_rank_map: + assert moe_rank_map[key]["ep_rank"] == new_ep_rank, "Error: Mapping exception occurred" + assert moe_rank_map[key]["edp_rank"] == new_edp_rank, "Error: Mapping exception occurred" + assert moe_rank_map[key]["ewp_rank"] == new_ewp_rank, "Error: Mapping exception occurred" + moe_rank_map[key] = {"ep_rank": new_ep_rank, "edp_rank": new_edp_rank, "ewp_rank": new_ewp_rank} + + if group_id not in new_states[new_tp_rank][new_pp_rank][new_zero1_rank]: + new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] = {} + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + assert global_fqn not in target_new_states, f"repeated global_fqn 
{global_fqn}" + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + + if tp_dim == -1: + target_new_states[global_fqn]["exp_avg"] = exp_avg_list[i].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = exp_avg_sq_list[i].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = fp32_weights_list[i].detach().clone() + else: + target_new_states[global_fqn]["exp_avg"] = new_exp_avg_splits[j].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = new_exp_avg_sq_splits[j].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = new_fp32_weights_splits[j].detach().clone() + + +def convert_moe_layer( + folder, + old_pp_size, + old_ewp_size, + new_ewp_size, + num_layers, + num_experts, + saved_folder, + old_tp_mode, + new_tp_mode, + old_moe_map_local_to_global, + old_moe_meta_data, + new_moe_meta_data, +): + print("Begin moe layer convert", flush=True) + for layer_id in range(num_layers): + old_pp_rank = layer_id // (num_layers // old_pp_size) + for ep_id in range(num_experts): + new_states = [{} for _ in range(new_ewp_size)] + if old_ewp_size >= new_ewp_size: + assert old_ewp_size % new_ewp_size == 0, f"Cannot convert {old_ewp_size} TP to {new_ewp_size} TP." + ratio = old_ewp_size // new_ewp_size + moe_tp_merge( + folder, + layer_id, + ep_id, + old_ewp_size, + new_ewp_size, + old_pp_rank, + old_tp_mode, + ratio, + old_moe_map_local_to_global, + new_states, + old_moe_meta_data, + new_moe_meta_data, + ) + else: + assert new_ewp_size % old_ewp_size == 0, f"Cannot convert {old_ewp_size} TP to {new_ewp_size} TP." + split_maps = get_mapping(old_ewp_size, new_ewp_size) + ratio = new_ewp_size // old_ewp_size + moe_tp_split( + folder, + layer_id, + ep_id, + old_ewp_size, + old_pp_rank, + old_tp_mode, + ratio, + split_maps, + old_moe_map_local_to_global, + new_states, + old_moe_meta_data, + new_moe_meta_data, + ) + + for new_ewp_rank in range(new_ewp_size): + file_name = f"model_moe_layer{layer_id}_expert{ep_id}_{new_tp_mode}{new_ewp_rank}.pt" + torch.save(new_states[new_ewp_rank], os.path.join(saved_folder, file_name)) + + print("Finish moe layer convert", flush=True) def convert_modeling_ckpt( @@ -384,57 +771,46 @@ def convert_modeling_ckpt( folder, saved_folder, new_states, + new_meta, ): print("Begin model convert", flush=True) for old_pp_rank in range(old_pp_size): - if old_tp_size != new_tp_size: - if old_tp_size > new_tp_size: - assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." - ratio = old_tp_size // new_tp_size - model_tp_merge( - old_pp_rank, - new_states, - old_tp_size, - new_tp_size, - old_tp_mode, - ratio, - old_meta_data, - new_meta_data, - old_map_local_to_global, - new_meta, - folder, - ) - else: - assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." - split_maps = get_mapping(old_tp_size, new_tp_size) - ratio = new_tp_size // old_tp_size - model_tp_split( - split_maps, - old_pp_rank, - old_tp_size, - new_states, - old_meta_data, - new_meta_data, - ratio, - old_tp_mode, - old_map_local_to_global, - new_meta, - folder, - ) + if old_tp_size >= new_tp_size: + assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." 
+ ratio = old_tp_size // new_tp_size + model_tp_merge( + old_pp_rank, + new_states, + old_tp_size, + new_tp_size, + old_tp_mode, + ratio, + old_meta_data, + new_meta_data, + old_map_local_to_global, + new_meta, + folder, + ) else: - for old_tp_rank in range(old_tp_size): - ckpt_states = torch.load( - os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" - ) - for fqn, tensor in ckpt_states.items(): - global_fqn = old_map_local_to_global[old_pp_rank][fqn] - new_pp_rank = new_meta_data[global_fqn]["pp"] - new_fqn = new_meta_data[global_fqn]["fqn"] - new_states[old_tp_rank][new_pp_rank][new_fqn] = tensor.detach().clone() + assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." + split_maps = get_mapping(old_tp_size, new_tp_size) + ratio = new_tp_size // old_tp_size + model_tp_split( + split_maps, + old_pp_rank, + old_tp_size, + new_states, + old_meta_data, + new_meta_data, + ratio, + old_tp_mode, + old_map_local_to_global, + new_meta, + folder, + ) for new_tp_rank in range(new_tp_size): for new_pp_rank in range(new_pp_size): - # print(f"pp={new_pp_rank}, tp={new_tp_rank}: {new_states[new_tp_rank][new_pp_rank].keys()}") file_name = f"model_{new_tp_mode}{new_tp_rank}_pp{new_pp_rank}.pt" states = sorted_state_dict(new_states[new_tp_rank][new_pp_rank]) torch.save(states, os.path.join(saved_folder, file_name)) @@ -451,79 +827,114 @@ def convert_optimizer_ckpt( new_tp_size, old_zero1_size, new_zero1_size, - old_meta_data, new_meta_data, new_tp_mode, saved_folder, new_states, processed_ckpt_states, + moe_group, + new_base_groups, + **kwargs, ): print("Begin optimizer convert", flush=True) for old_pp_rank in range(old_pp_size): for old_zero1_rank in range(old_zero1_size): - if old_tp_size != new_tp_size: - if old_tp_size > new_tp_size: - assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." - ratio = old_tp_size // new_tp_size - optimizer_tp_merge( - new_tp_size, - old_tp_size, - old_pp_rank, - old_zero1_rank, - old_meta, - old_meta_data, - new_meta_data, - processed_ckpt_states, - new_states, - ratio, - ) - else: - assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." - split_maps = get_mapping(old_tp_size, new_tp_size) - ratio = new_tp_size // old_tp_size - optimizer_tp_split( - split_maps, - old_tp_size, - old_pp_rank, - old_zero1_rank, - old_meta, - new_meta_data, - processed_ckpt_states, - new_states, - ratio, - ) + if old_tp_size >= new_tp_size: + assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." 
+ ratio = old_tp_size // new_tp_size + optimizer_tp_merge( + new_tp_size, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, + moe_group, + ) else: - for old_tp_rank in range(old_tp_size): - ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] - for group_id in ckpt_states["flat_fp32_weights"].keys(): - old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] - exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] - exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] - fp32_weights_list = ckpt_states["flat_fp32_weights"][group_id] - for i, global_fqn in enumerate(list(old_metaData.keys())): - new_pp_rank = new_meta_data[global_fqn]["pp"] - new_zero1_rank = new_meta_data[global_fqn]["zero1"] - target_new_states = new_states[old_tp_rank][new_pp_rank][new_zero1_rank][group_id] - if global_fqn not in target_new_states: - target_new_states[global_fqn] = { - "exp_avg": None, - "exp_avg_sq": None, - "fp32_weights": None, - } - target_new_states[global_fqn]["exp_avg"] = exp_avg_list[i].detach().clone() - target_new_states[global_fqn]["exp_avg_sq"] = exp_avg_sq_list[i].detach().clone() - target_new_states[global_fqn]["fp32_weights"] = fp32_weights_list[i].detach().clone() + assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." + split_maps = get_mapping(old_tp_size, new_tp_size) + ratio = new_tp_size // old_tp_size + optimizer_tp_split( + split_maps, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, + moe_group, + ) + + if len(moe_group) > 0: + print("Begin optimizer moe group convert", flush=True) + old_ep_size = kwargs["old_ep_size"] + old_edp_size = kwargs["old_edp_size"] + old_ewp_size = kwargs["old_ewp_size"] + new_ewp_size = kwargs["new_ewp_size"] + tp_mode = kwargs["tp_mode"] + old_moe_meta_data = kwargs["old_moe_meta_data"] + new_moe_meta_data = kwargs["new_moe_meta_data"] + moe_rank_map = {} + + for old_pp_rank in range(old_pp_size): + for old_ep_rank in range(old_ep_size): + for old_edp_rank in range(old_edp_size): + if old_ewp_size >= new_ewp_size: + assert ( + old_ewp_size % new_ewp_size == 0 + ), f"Cannot convert {old_ewp_size} ewp/tp to {new_ewp_size} ewp/tp." + ratio = old_ewp_size // new_ewp_size + optimizer_moe_tp_merge( + new_ewp_size, + old_ewp_size, + old_pp_rank, + old_ep_rank, + old_edp_rank, + old_meta, + new_meta, + old_moe_meta_data, + new_moe_meta_data, + tp_mode, + processed_ckpt_states, + new_states, + ratio, + moe_group, + moe_rank_map, + ) + else: + assert ( + new_ewp_size % old_ewp_size == 0 + ), f"Cannot convert {old_ewp_size} ewp/tp to {new_ewp_size} ewp/tp." 
+ split_maps = get_mapping(old_ewp_size, new_ewp_size) + ratio = new_ewp_size // old_ewp_size + optimizer_moe_tp_split( + split_maps, + old_pp_rank, + old_ewp_size, + old_ep_rank, + old_edp_rank, + old_meta, + new_meta, + new_moe_meta_data, + tp_mode, + processed_ckpt_states, + new_states, + ratio, + moe_group, + moe_rank_map, + ) for new_tp_rank in range(new_tp_size): for new_pp_rank in range(new_pp_size): for new_zero1_rank in range(new_zero1_size): file_name = f"optimizer_{new_tp_mode}{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}.pt" optimizer_state = new_states[new_tp_rank][new_pp_rank][new_zero1_rank] - metaData = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank] - assert set(optimizer_state.keys()) == set( - metaData.keys() - ), f"group_id error: state {list((optimizer_state.keys()))} is different from {list(metaData.keys())}" - base_state = processed_ckpt_states[0][0][0] step = base_state["base_optim_states"]["state"][0]["step"] base_state["base_optim_states"]["state"] = {} @@ -531,7 +942,18 @@ def convert_optimizer_ckpt( if "zero_devide_optim_plan" in base_state: base_state.pop("zero_devide_optim_plan") - for group_id in optimizer_state.keys(): + # Ensure that the order of group_id is consistent + sorted_groups = sorted(list(optimizer_state.keys())) + for group_id in sorted_groups: + if group_id in moe_group: + key = f"{new_tp_rank}_{new_pp_rank}_{new_zero1_rank}" + ep_rank = moe_rank_map[key]["ep_rank"] + edp_rank = moe_rank_map[key]["edp_rank"] + ewp_rank = moe_rank_map[key]["ewp_rank"] + metaData = new_meta["moe_meta"][new_pp_rank][ep_rank][edp_rank][ewp_rank]["metaData"] + else: + metaData = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank] + meta_fqns = metaData[group_id].keys() sorted_exp_avg, sorted_exp_avg_sq, sorted_fp32_weights = sort_optimizer_state( optimizer_state[group_id], meta_fqns @@ -543,6 +965,31 @@ def convert_optimizer_ckpt( base_state["base_optim_states"]["state"][group_id] = state base_state["flat_fp32_weights"][group_id] = flat_fp32_weights + # overwrite base states group_id + base_groups = new_base_groups[new_tp_rank][new_pp_rank][new_zero1_rank] + for group_id in base_groups: + if group_id not in base_state["base_optim_states"]["state"]: + continue + base_group_id = base_groups[group_id][0] + if group_id in base_state["base_optim_states"]["state"]: + assert len(base_groups[group_id]) == 1 + if group_id != base_group_id: + base_state["base_optim_states"]["state"][base_group_id] = base_state["base_optim_states"][ + "state" + ].pop(group_id) + base_state["base_optim_states"]["param_groups"][group_id]["params"] = base_groups[group_id] + + if len(moe_group) > 0: + base_groups = kwargs["new_moe_base_groups"][new_pp_rank][ep_rank][edp_rank][ewp_rank] + for group_id in base_groups: + base_group_id = base_groups[group_id][0] + if group_id in base_state["base_optim_states"]["state"]: + assert len(base_groups[group_id]) == 1 + if group_id != base_group_id: + base_state["base_optim_states"]["state"][base_group_id] = base_state[ + "base_optim_states" + ]["state"].pop(group_id) + base_state["base_optim_states"]["param_groups"][group_id]["params"] = base_groups[group_id] torch.save(base_state, os.path.join(saved_folder, file_name)) print("Finish optimizer convert", flush=True) @@ -575,6 +1022,7 @@ def convert_optimizer_ckpt( old_meta = torch.load(old_meta_path, map_location="cpu") old_pp_size = old_meta["parallel_setting"]["pp_size"] old_zero1_size = old_meta["parallel_setting"]["zero1_size"] + if "tp_size" in old_meta["parallel_setting"]: 
        old_tp_mode = "tp"
    elif "wp_size" in old_meta["parallel_setting"]:
@@ -583,22 +1031,39 @@
         assert False, "tp or wp should be in parallel setting."
     old_tp_size = old_meta["parallel_setting"][f"{old_tp_mode}_size"]
 
-    # To facilitate key query, summarize meta_data.
-    old_meta_data = {}
-    for pp_rank in range(old_pp_size):
-        for zero_rank in range(old_zero1_size):
-            for states in old_meta["metaData"][0][pp_rank][zero_rank].values():
-                old_meta_data.update(states)
+    # preprocess base group_id
+    # old_base_groups[tp_pp_zero1_groupId] = [base_group_id]
+    old_base_groups = [[[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size)]
+    for tp_rank in range(old_tp_size):
+        for pp_rank in range(old_pp_size):
+            for zero_rank in range(old_zero1_size):
+                old_base_groups[tp_rank][pp_rank][zero_rank] = old_meta["metaData"][tp_rank][pp_rank][zero_rank].pop(
+                    "base_groups"
+                )
+
+    # To facilitate key query, aggregate meta_data.
+    old_meta_data = [{} for _ in range(old_tp_size)]
+    for tp_rank in range(old_tp_size):
+        for pp_rank in range(old_pp_size):
+            for zero_rank in range(old_zero1_size):
+                assert "base_groups" not in old_meta["metaData"][tp_rank][pp_rank][zero_rank]
+                for states in old_meta["metaData"][tp_rank][pp_rank][zero_rank].values():
+                    old_meta_data[tp_rank].update(states)
 
     # map local fqn to global fqn
     old_map_local_to_global = [{} for _ in range(old_pp_size)]
-    for global_fqn, states in old_meta_data.items():
-        old_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn
+    for tp_rank in range(old_tp_size):
+        for global_fqn, states in old_meta_data[tp_rank].items():
+            if states["fqn"] not in old_map_local_to_global[states["pp"]]:
+                old_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn
+            else:
+                assert global_fqn == old_map_local_to_global[states["pp"]][states["fqn"]]
 
     # read and process metaData for target ckpt
     new_meta = torch.load(new_meta_path, map_location="cpu")
     new_pp_size = new_meta["parallel_setting"]["pp_size"]
     new_zero1_size = new_meta["parallel_setting"]["zero1_size"]
+
     if "tp_size" in new_meta["parallel_setting"]:
         new_tp_mode = "tp"
     elif "wp_size" in new_meta["parallel_setting"]:
@@ -608,17 +1073,96 @@
     # TODO: support converting between tp and wp
     assert old_tp_mode == new_tp_mode, "Do not support converting between tp and wp currently."
     new_tp_size = new_meta["parallel_setting"][f"{new_tp_mode}_size"]
+
+    # preprocess base group_id
+    # new_base_groups[tp_pp_zero1_groupId] = [base_group_id]
+    new_base_groups = [[[{} for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size)]
+    for tp_rank in range(new_tp_size):
+        for pp_rank in range(new_pp_size):
+            for zero_rank in range(new_zero1_size):
+                new_base_groups[tp_rank][pp_rank][zero_rank] = new_meta["metaData"][tp_rank][pp_rank][zero_rank].pop(
+                    "base_groups"
+                )
+
     assert set(new_meta["metaData"][0][0][0].keys()) == set(
         old_meta["metaData"][0][0][0].keys()
     ), "Error: old meta and new meta have different group_id lists."
     group_id_list = list(new_meta["metaData"][0][0][0].keys())
 
-    # To facilitate key query, summarize meta_data.
-    new_meta_data = {}
-    for pp_rank in range(new_pp_size):
-        for zero_rank in range(new_zero1_size):
-            for states in new_meta["metaData"][0][pp_rank][zero_rank].values():
-                new_meta_data.update(states)
+    # To facilitate key query, aggregate meta_data.
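+    # Inferred layout note: new_meta_data[tp_rank] maps each global fqn to its
+    # placement entry (fields such as "pp", "zero1" and the local "fqn").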
+    new_meta_data = [{} for _ in range(new_tp_size)]
+    for tp_rank in range(new_tp_size):
+        for pp_rank in range(new_pp_size):
+            for zero_rank in range(new_zero1_size):
+                assert "base_groups" not in new_meta["metaData"][tp_rank][pp_rank][zero_rank]
+                for states in new_meta["metaData"][tp_rank][pp_rank][zero_rank].values():
+                    new_meta_data[tp_rank].update(states)
+
+    # moe: layer and expert counts must match between the old and new ckpt
+    num_layers = old_meta["parallel_setting"]["num_layers"]
+    num_experts = old_meta["parallel_setting"]["num_experts"]
+    assert num_layers == new_meta["parallel_setting"]["num_layers"]
+    assert num_experts == new_meta["parallel_setting"]["num_experts"]
+    if num_experts > 1:
+        old_ep_size = old_meta["parallel_setting"]["ep_size"]
+        old_edp_size = old_meta["parallel_setting"]["edp_size"]
+        old_ewp_size = old_meta["parallel_setting"]["ewp_size"]
+        new_ep_size = new_meta["parallel_setting"]["ep_size"]
+        new_edp_size = new_meta["parallel_setting"]["edp_size"]
+        new_ewp_size = new_meta["parallel_setting"]["ewp_size"]
+
+        assert len(old_meta["moe_group"]) > 0, "moe group should not be empty"
+        assert (
+            old_meta["moe_group"] == new_meta["moe_group"]
+        ), "Error: old meta and new meta have different moe group lists."
+
+        # preprocess base group_id
+        # old_moe_base_groups[pp_ep_edp_ewp_groupId] = [base_group_id]
+        old_moe_base_groups = [
+            [[[{} for _ in range(old_ewp_size)] for _ in range(old_edp_size)] for _ in range(old_ep_size)]
+            for _ in range(old_pp_size)
+        ]
+        for pp_rank in range(old_pp_size):
+            for ep_rank in range(old_ep_size):
+                for edp_rank in range(old_edp_size):
+                    for ewp_rank in range(old_ewp_size):
+                        old_moe_base_groups[pp_rank][ep_rank][edp_rank][ewp_rank] = old_meta["moe_meta"][pp_rank][
+                            ep_rank
+                        ][edp_rank][ewp_rank]["metaData"].pop("base_groups")
+
+        old_moe_meta_data = {}
+        for pp_rank in range(old_pp_size):
+            for ep_rank in range(old_ep_size):
+                for edp_rank in range(old_edp_size):
+                    assert "base_groups" not in old_meta["moe_meta"][pp_rank][ep_rank][edp_rank][0]["metaData"]
+                    for states in old_meta["moe_meta"][pp_rank][ep_rank][edp_rank][0]["metaData"].values():
+                        old_moe_meta_data.update(states)
+
+        old_moe_map_local_to_global = [{} for _ in range(old_pp_size)]
+        for global_fqn, states in old_moe_meta_data.items():
+            old_moe_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn
+
+        # preprocess base group_id
+        # new_moe_base_groups[pp_ep_edp_ewp_groupId] = [base_group_id]
+        new_moe_base_groups = [
+            [[[{} for _ in range(new_ewp_size)] for _ in range(new_edp_size)] for _ in range(new_ep_size)]
+            for _ in range(new_pp_size)
+        ]
+        for pp_rank in range(new_pp_size):
+            for ep_rank in range(new_ep_size):
+                for edp_rank in range(new_edp_size):
+                    for ewp_rank in range(new_ewp_size):
+                        new_moe_base_groups[pp_rank][ep_rank][edp_rank][ewp_rank] = new_meta["moe_meta"][pp_rank][
+                            ep_rank
+                        ][edp_rank][ewp_rank]["metaData"].pop("base_groups")
+
+        new_moe_meta_data = {}
+        for pp_rank in range(new_pp_size):
+            for ep_rank in range(new_ep_size):
+                for edp_rank in range(new_edp_size):
+                    assert "base_groups" not in new_meta["moe_meta"][pp_rank][ep_rank][edp_rank][0]["metaData"]
+                    for states in new_meta["moe_meta"][pp_rank][ep_rank][edp_rank][0]["metaData"].values():
+                        new_moe_meta_data.update(states)
 
     new_states = [[{} for _ in range(new_pp_size)] for _ in range(new_tp_size)]
     convert_modeling_ckpt(
@@ -634,15 +1178,53 @@
         folder=folder,
         saved_folder=saved_folder,
         new_states=new_states,
+        new_meta=new_meta,
     )
 
+    if num_experts > 1:
+        convert_moe_layer(
+            folder=folder,
+            old_pp_size=old_pp_size,
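+            # Presumably only the expert weight-parallel (ewp) dimension is
+            # resharded here: old/new sizes are passed for ewp only, while the
+            # pp layout is taken from the source ckpt (inferred from the args).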
+ old_ewp_size=old_ewp_size, + new_ewp_size=new_ewp_size, + num_layers=num_layers, + num_experts=num_experts, + saved_folder=saved_folder, + old_tp_mode=old_tp_mode, + new_tp_mode=new_tp_mode, + old_moe_map_local_to_global=old_moe_map_local_to_global, + old_moe_meta_data=old_moe_meta_data, + new_moe_meta_data=new_moe_meta_data, + ) + if args.convert_optimizer: + if num_experts > 1: + kwargs = { + "old_ep_size": old_ep_size, + "old_edp_size": old_edp_size, + "old_ewp_size": old_ewp_size, + "new_ewp_size": new_ewp_size, + "tp_mode": old_tp_mode, + "old_moe_meta_data": old_moe_meta_data, + "new_moe_meta_data": new_moe_meta_data, + "old_moe_base_groups": old_moe_base_groups, + "new_moe_base_groups": new_moe_base_groups, + } + else: + kwargs = {} + processed_ckpt_states = preprocess_optimizer_state( - old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode + old_tp_size, + old_pp_size, + old_zero1_size, + old_meta, + folder, + old_tp_mode, + old_base_groups, + old_meta["moe_group"], + **kwargs, ) - new_states = [ - [[defaultdict(dict) for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size) - ] + new_states = [[[{} for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size)] convert_optimizer_ckpt( old_meta=old_meta, new_meta=new_meta, @@ -652,12 +1234,14 @@ def convert_optimizer_ckpt( new_tp_size=new_tp_size, old_zero1_size=old_zero1_size, new_zero1_size=new_zero1_size, - old_meta_data=old_meta_data, new_meta_data=new_meta_data, new_tp_mode=new_tp_mode, saved_folder=saved_folder, new_states=new_states, processed_ckpt_states=processed_ckpt_states, + moe_group=old_meta["moe_group"], + new_base_groups=new_base_groups, + **kwargs, ) if args.copy_file: