diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py index e69de29bb2d1..edb424e9dfa8 100644 --- a/deepspeed/checkpoint/__init__.py +++ b/deepspeed/checkpoint/__init__.py @@ -0,0 +1,13 @@ +from .reshape_meg_2d import reshape_meg_2d_parallel + +from .deepspeed_checkpoint import DeepSpeedCheckpoint + +from .utils import (get_layer_ckpt_name_for_rank, + get_model_ckpt_name_for_rank, + get_zero_ckpt_name_for_rank) + +from .reshape_utils import (merge_state) + +from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) + +from .zero_checkpoint import ZeROCheckpoint diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index f45777025db7..dc79df643af2 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -11,15 +11,30 @@ BASE_OPTIMIZER_STATE = 'base_optimizer_state' SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" -GROUPS_PADDING = 'groups_padding' - +GROUP_PADDINGS = 'group_paddings' PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' CLIP_GRAD = 'clip_grad' +PARAM_SLICE_MAPPINGS = 'param_slice_mappings' +FP32_WEIGHT_KEY = "fp32" ######################################### # Module checkpoint keys ######################################### PARAM_SHAPES = 'param_shapes' BUFFER_NAMES = 'buffer_names' + +######################################### +# Checkpoint naming constants +######################################### +MODEL_FILE_PREFIX = 'mp_rank_' +ZERO_FILE_PREFIX = 'bf16_' + 'zero_pp_rank_' +OPTIM_FILE_SUFFIX = '_optim_states.pt' +MODEL_FILE_SUFFIX = '_model_states.pt' +LAYER_FILE_PREFIX = 'layer_' +BF16_ZERO_FILE_PREFIX = ZERO_FILE_PREFIX + +######################################### +# Checkpoint utility keys +######################################### DS_VERSION = 'ds_version' diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py new file mode 100644 index 000000000000..4b8d31e832d7 --- /dev/null +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -0,0 +1,316 @@ +import os +from typing import Dict +import torch + +from .reshape_3d_utils import model_3d_desc +from .reshape_utils import (basic_folder_validation, + merge_state, + partition_data, + get_files, + get_files_with_prefix) + +from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX) + +from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map +from .zero_checkpoint import ZeROCheckpoint +from .constants import * + +EMBEDDING_LAYER_INDEX = 0 +FINAL_LAYER_NORM_INDEX = -1 +ARGS_KEY = 'args' +CHECKPOINT_INFO_KEY = 'checkpoint_info' +ITERATION_KEY = 'iteration' + +SEQUENTIAL_LAYERS = [ + 'input_layernorm.weight', + 'input_layernorm.bias', + 'self_attention.dense.bias', + 'post_attention_layernorm.weight', + 'post_attention_layernorm.bias', + 'mlp.dense_4h_to_h.bias', + 'position_embeddings.weight' +] + +LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1} + + +class DeepSpeedCheckpoint(object): + def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): + self.dir = dir + self._validate_folder(dir) + + self.file_list = get_files(dir) + self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) + self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX) + self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX) + + self.layer_keys = self._get_layer_keys() + self.layer_count = len(self.layer_keys) + self.original_tp_degree = 
len( + get_files_with_prefix(self.layer_files, + f'{LAYER_FILE_PREFIX}01')) + self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree + self.original_dp_degree = max( + 1, + len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)) + + self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree + self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree + self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree + + self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree + self.world_size = self.tp_degree * self.pp_degree * self.dp_degree + + self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree, + self.original_tp_degree) + self.old_2d_map.simple_init() + self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree, + old_tp_degree=self.original_tp_degree, + new_pp_degree=self.pp_degree, + new_tp_degree=self.tp_degree) + + self.zero_checkpoint = ZeROCheckpoint(dir) + if self.is_change_pp_degree() or self.is_change_tp_degree( + ) or self.is_change_dp_degree(): + self.zero_checkpoint.reshape( + model_3d_desc(self.pp_degree, + self.tp_degree, + self.dp_degree)) + + self.global_state = {} + + self._sanity_check() + self.pp_to_transformer_map = self._build_pp_transformer_map() + self.transformer_file_map = self._build_transformer_file_map() + self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) + self.tp_to_final_norm_map = self._build_tp_other_layer_map( + FINAL_LAYER_NORM_INDEX) + self._build_global_state() + + def is_change_tp_degree(self): + return self.tp_degree != self.original_tp_degree + + def is_change_pp_degree(self): + return self.pp_degree != self.original_pp_degree + + def is_change_dp_degree(self): + return self.dp_degree != self.original_dp_degree + + def show_2d_mapping(self): + print(f'reshaped 2d map ---- begin') + + for i in range(self.pp_degree): + for j in range(self.tp_degree): + file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j) + print(f'[{i}, {j}] = {file_list}') + + print(f'reshaped 2d map ---- end') + + def show_tp_embedding_map(self): + self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') + + def show_tp_final_norm_map(self): + self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers') + + def show_pp_tranformer_map(self): + self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers') + + def show_transformer_file_map(self): + self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files') + + def _build_global_state(self): + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) + self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) + + def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict: + return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index, + keys_to_ignore=[PARAM_SHAPES]) + + def get_zero_files(self, pp_index, tp_index, dp_index) -> list: + return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index) + + def get_embedding_layer_id(self): + return self.layer_keys[EMBEDDING_LAYER_INDEX] + + def get_final_norm_layer_id(self): + return self.layer_keys[FINAL_LAYER_NORM_INDEX] + + def get_iteration(self): + if not ITERATION_KEY in self.global_state: + sd = torch.load(self.mp_rank_files[0], 
map_location=torch.device('cpu')) + self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) + + return self.global_state[ITERATION_KEY] + + def get_embedding_state(self, tp_index: int) -> Dict: + assert tp_index in self.tp_to_embedding_map.keys() + sd_list = [ + torch.load(fname, + map_location=torch.device('cpu')) + for fname in self.tp_to_embedding_map[tp_index] + ] + sd = self._merge_state_dicts(sd_list) + return sd + + def get_embedding_files(self, tp_index: int) -> list: + assert tp_index in self.tp_to_embedding_map.keys() + return self.tp_to_embedding_map[tp_index] + + def _get_checkpoint_value(self, key): + if not key in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[key] = sd.get(key, None) + + return self.global_state[key] + + def get_args(self): + return self._get_checkpoint_value(ARGS_KEY) + + def get_checkpoint_info(self): + return self._get_checkpoint_value(CHECKPOINT_INFO_KEY) + + def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index) + sd_list = [ + torch.load(fname, + map_location=torch.device('cpu')) for fname in fname_list + ] + + merged_sd = None + for sd in sd_list: + if merged_sd is None: + merged_sd = sd + else: + merged_sd = merge_state(merged_sd, sd) + + return merged_sd + + def get_transformer_state(self, tp_index: int, pp_index: int) -> list: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + t_list = [] + for fname_list in self.transformer_file_map[(tp_index, pp_index)]: + sd_list = [ + torch.load(fname, + map_location=torch.device('cpu')) for fname in fname_list + ] + sd = self._merge_state_dicts(sd_list) + t_list.append(sd) + return t_list + + def get_pp_transformer_map(self, pp_index: int) -> list: + assert pp_index < self.pp_degree + return self.pp_to_transformer_map[pp_index] + + def get_final_norm_state(self, tp_index: int) -> Dict: + assert tp_index in self.tp_to_final_norm_map.keys() + sd = torch.load(self.tp_to_final_norm_map[tp_index][0], + map_location=torch.device('cpu')) + return sd + + def get_final_norm_files(self, tp_index: int) -> list: + assert tp_index in self.tp_to_final_norm_map.keys() + return self.tp_to_final_norm_map[tp_index] + + def _build_tp_other_layer_map(self, layer_index: int): + assert layer_index < len(self.layer_files) + layer_files = get_files_with_prefix(self.layer_files, + self.layer_keys[layer_index]) + layer_file_partitions = partition_data(layer_files, self.tp_degree) + data_map = {i: flist for i, flist in enumerate(layer_file_partitions)} + return data_map + + def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + file_indices = self.new_2d_map.get_data(pp_index=pp_index, tp_index=tp_index) + return [self.mp_rank_files[i] for i in file_indices] + + def _build_pp_transformer_map(self): + data_map = {} + transformer_layers = self.layer_keys[1:-1] + layers_per_pp = len(transformer_layers) // self.pp_degree + data_map = { + i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] + for i in range(0, + self.pp_degree) + } + return data_map + + def _dump_mapping(self, data_map, map_tag=None): + if map_tag is not None: + print(f'Dump mapping: {map_tag}') + for k, v in data_map.items(): + print(f'{k} = {v}') + + def _build_transformer_file_map(self): + transformer_layer_keys = 
self.layer_keys[1:-1] + file_map = {} + # XXX: this is not guaranteed + layers_per_pp = len(transformer_layer_keys) // self.pp_degree + if layers_per_pp == 0: + layers_per_pp = 1 + #print(f"{transformer_layer_keys} {layers_per_pp}") + for key_index, layer_key in enumerate(transformer_layer_keys): + pp_index = key_index // layers_per_pp + layer_files = get_files_with_prefix(self.layer_files, layer_key) + layer_file_partitions = partition_data(layer_files, self.tp_degree) + for tp_index in range(self.tp_degree): + map_key = (tp_index, pp_index) + if not map_key in file_map.keys(): + file_map[map_key] = [] + file_map[map_key].append(layer_file_partitions[tp_index]) + + return file_map + + def _sanity_check(self): + assert len(self.mp_rank_files) % self.tp_degree == 0 + assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0 + assert len(self.layer_keys) > 2 + # XXX: fix me - isn't always the case + # only true with --pp-partition-method 'type:transformer|embedding' \ + # assert (len(self.layer_keys) - 2) % self.pp_degree == 0 + + def validate_files(self): + for file in self.file_list: + if not os.path.isfile(file): + print(f'Error: {file} is not existent') + + def _get_layer_keys(self): + key_set = set() + key_len = len(LAYER_FILE_PREFIX) + 2 + for file_path in self.layer_files: + _, fname = os.path.split(file_path) + key_set.add(fname[:key_len]) + return sorted(list(key_set)) + + def _merge_state_dicts(self, sd_list): + merged_sd = {} + for key in sd_list[0].keys(): + if not key in SEQUENTIAL_LAYERS: + cat_dim = LAYER_CONCAT_DIM.get(key, 0) + merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim) + else: + merged_sd[key] = sd_list[0][key] + + return merged_sd + + def _validate_folder(self, dir): + basic_folder_validation(dir) + + file_list = get_files(dir) + + for file_prefix in [ + MODEL_FILE_PREFIX, + LAYER_FILE_PREFIX, + f'{LAYER_FILE_PREFIX}01' + ]: + ckpt_files = get_files_with_prefix(file_list, file_prefix) + assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.' 
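For reviewers, here is a minimal usage sketch of the new DeepSpeedCheckpoint class added above (illustrative only, not part of the patch). The checkpoint path and target degrees are made up; per the reshape logic in this PR, the target TP/PP/DP degrees may only contract, never exceed the degrees the checkpoint was saved with.

# Illustrative usage sketch -- assumes a Megatron-DeepSpeed checkpoint folder
# containing layer_*, mp_rank_* and bf16_zero_pp_rank_* files; the path and the
# target degrees below are hypothetical.
from deepspeed.checkpoint import DeepSpeedCheckpoint

ckpt_dir = '/path/to/global_step1000'
ds_ckpt = DeepSpeedCheckpoint(ckpt_dir, tp_degree=1, pp_degree=1)  # contract to TP=1, PP=1

print(f'iteration = {ds_ckpt.get_iteration()}')
ds_ckpt.show_2d_mapping()  # which mp_rank_* files feed each (pp, tp) slot after reshaping

# Merged (un-sliced) module states for the single remaining TP/PP rank
embedding_sd = ds_ckpt.get_embedding_state(tp_index=0)
transformer_sds = ds_ckpt.get_transformer_state(tp_index=0, pp_index=0)
final_norm_sd = ds_ckpt.get_final_norm_state(tp_index=0)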
diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py new file mode 100644 index 000000000000..b625eb222589 --- /dev/null +++ b/deepspeed/checkpoint/reshape_3d_utils.py @@ -0,0 +1,105 @@ +from .reshape_utils import (get_files, get_files_with_prefix, partition_data) + +from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX) + +from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map) + +PP_DIM = 'PP' +TP_DIM = 'TP' +DP_DIM = 'DP' + + +class model_3d_desc(object): + def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1): + self.pp_degree = pp_degree + self.tp_degree = tp_degree + self.dp_degree = dp_degree + + def reshape(self, target_3d_desc, verbose=False): + valid_reshape, reshape_errors = self.can_reshape(target_3d_desc) + assert valid_reshape, ','.join(reshape_errors) + tgt_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.pp_degree, + old_tp_degree=self.tp_degree, + new_pp_degree=target_3d_desc.pp_degree, + new_tp_degree=target_3d_desc.tp_degree, + verbose=verbose) + + flat_3d_map = flatten_dp_dimension(meg_2d_map=tgt_2d_map, + src_2d_size=self.pp_degree * self.tp_degree, + dp_degree=self.dp_degree) + + return unflatten_dp_dimension(meg_2d_map=flat_3d_map, + dp_degree=target_3d_desc.dp_degree) + + def get_desc(self): + return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})' + + def is_valid(self, pp_index, tp_index, dp_index): + err_msg = [] + valid = True + for index, degree, dim_name in [ + (pp_index, self.pp_degree, PP_DIM), + (tp_index, self.tp_degree, TP_DIM), + (dp_index, self.dp_degree, DP_DIM)]: + if index >= degree: + valid = False + err_msg.append( + f'{dim_name} indexing error: index {index} >= degree {degree}') + + return valid, err_msg + + def can_reshape(self, target_3d_desc): + err_msg = [] + if target_3d_desc.pp_degree > self.pp_degree: + err_msg.append( + f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}' + ) + + if target_3d_desc.tp_degree > self.tp_degree: + err_msg.append( + f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}' + ) + + if target_3d_desc.dp_degree > self.dp_degree: + err_msg.append( + f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}' + ) + + return len(err_msg) == 0, err_msg + + +def get_model_3d_descriptor(dir): + file_list = get_files(dir) + tp_degree = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01')) + pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree + num_zero_files = len(get_files_with_prefix(file_list, ZERO_FILE_PREFIX)) + dp_degree = max(1, num_zero_files // (pp_degree * tp_degree)) + return model_3d_desc(pp_degree, tp_degree, dp_degree) + + +def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree): + new_meg_2d_map = meg_2d_parallel_map(meg_2d_map.pp_degree, meg_2d_map.tp_degree) + for pp_index in range(meg_2d_map.pp_degree): + for tp_index in range(meg_2d_map.tp_degree): + dp0_indices = meg_2d_map.get_data(pp_index, tp_index) + for idx in dp0_indices: + dpX_indices = [idx + (i * src_2d_size) for i in range(dp_degree)] + new_meg_2d_map.add_data(pp_index, tp_index, dpX_indices) + return new_meg_2d_map + + +def unflatten_dp_dimension(meg_2d_map, dp_degree): + pp_degree = meg_2d_map.pp_degree + tp_degree = meg_2d_map.tp_degree + meg_2d_map_list = [ + meg_2d_parallel_map(pp_degree=pp_degree, + tp_degree=tp_degree) for _ in range(dp_degree) 
+ ] + for pp_index in range(pp_degree): + for tp_index in range(tp_degree): + flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index) + partitioned_dp_indices = partition_data(flat_dp_indices, dp_degree) + for dp_indices, _2d_map in zip(partitioned_dp_indices, meg_2d_map_list): + _2d_map.add_data(pp_index, tp_index, dp_indices) + + return meg_2d_map_list diff --git a/deepspeed/checkpoint/reshape_meg_2d.py b/deepspeed/checkpoint/reshape_meg_2d.py new file mode 100644 index 000000000000..0d7cd233c78a --- /dev/null +++ b/deepspeed/checkpoint/reshape_meg_2d.py @@ -0,0 +1,226 @@ +from .reshape_utils import partition_data + + +class meg_2d_parallel_map(object): + def __init__(self, pp_degree, tp_degree): + self.pp_degree = pp_degree + self.tp_degree = tp_degree + self.map = {} + + def simple_init(self): + self.map = { + self._make_key(i // self.tp_degree, + i % self.tp_degree): [i] + for i in range(self.pp_degree * self.tp_degree) + } + + def add_data(self, pp_index, tp_index, data): + self._validate_indices(pp_index, tp_index) + assert type(data) is list + + key = self._make_key(pp_index, tp_index) + if not key in self.map.keys(): + self.map[key] = [] + self.map[key] += data + + def get_data(self, pp_index=None, tp_index=None): + self._validate_indices(pp_index, tp_index) + pp_indices = list(range(self.pp_degree)) if pp_index is None else [pp_index] + tp_indices = list(range(self.tp_degree)) if tp_index is None else [tp_index] + + result = [] + for i in pp_indices: + for j in tp_indices: + result += self.map[self._make_key(i, j)] + + return result + + def print_data(self, tag): + print(f'{tag}') + for key, value in self.map.items(): + print(f'{key} = {value}') + + def _validate_indices(self, pp_index, tp_index): + assert pp_index is None or pp_index < self.pp_degree + assert tp_index is None or tp_index < self.tp_degree + + def _make_key(self, i, j): + return f'{i},{j}' + + +def _reshape_tp_dimension(old_2d_map, new_tp_degree): + old_pp_degree = old_2d_map.pp_degree + new_2d_map = meg_2d_parallel_map(old_pp_degree, new_tp_degree) + for i in range(old_pp_degree): + ranks_for_pp_index = old_2d_map.get_data(pp_index=i, tp_index=None) + split_ranks = partition_data(ranks_for_pp_index, new_tp_degree) + for j in range(new_tp_degree): + new_2d_map.add_data(i, j, split_ranks[j]) + + return new_2d_map + + +def _reshape_pp_dimension(old_2d_map, new_pp_degree): + old_tp_degree = old_2d_map.tp_degree + new_2d_map = meg_2d_parallel_map(new_pp_degree, old_tp_degree) + for i in range(old_tp_degree): + ranks_for_tp_index = old_2d_map.get_data(pp_index=None, tp_index=i) + split_ranks = partition_data(ranks_for_tp_index, new_pp_degree) + for j in range(new_pp_degree): + new_2d_map.add_data(j, i, split_ranks[j]) + + return new_2d_map + + +def reshape_meg_2d_parallel(old_pp_degree, + old_tp_degree, + new_pp_degree, + new_tp_degree, + verbose=False): + assert new_pp_degree <= old_pp_degree + assert new_tp_degree <= old_tp_degree + + old_2d_map = meg_2d_parallel_map(old_pp_degree, old_tp_degree) + old_2d_map.simple_init() + if verbose: + old_2d_map.print_data(f'original_2d_map:') + + if old_tp_degree != new_tp_degree: + new_tp_map = _reshape_tp_dimension(old_2d_map, new_tp_degree) + else: + new_tp_map = old_2d_map + if verbose: + new_tp_map.print_data(f'after_tp_reshape:') + + if old_pp_degree != new_pp_degree: + final_map = _reshape_pp_dimension(new_tp_map, new_pp_degree) + else: + final_map = new_tp_map + + if verbose: + final_map.print_data(f'final_2d_map:') + + return final_map + + +def 
get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None): + """ + Initialize model data parallel groups. + + Arguments: + tp_size: number of GPUs used to parallelize model tensor. + pp_size: number of GPUs used to parallelize model pipeline. + dp_size: number of GPUs used to parallelize model data. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + + world_size = tp_size * pp_size * dp_size + + print(f"\n\n*** tp={tp_size}, pp={pp_size}, dp={dp_size}, world={world_size}") + + tensor_model_parallel_size = min(tp_size, world_size) + pipeline_model_parallel_size = min(pp_size, world_size) + data_parallel_size = world_size // (tensor_model_parallel_size * + pipeline_model_parallel_size) + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + num_data_parallel_groups = world_size // data_parallel_size + + # Build the data-parallel groups. + all_dp_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_dp_group_ranks.append(list(ranks)) + + print("DP", all_dp_group_ranks) + + # Build the model-parallel groups. + all_pp_group_ranks = [] + for i in range(data_parallel_size): + ranks = [ + data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_dp_group_ranks + ] + all_pp_group_ranks.append(list(ranks)) + + print(f"PP", all_pp_group_ranks) + + # Build the tensor model-parallel groups. + all_tp_group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size) + all_tp_group_ranks.append(list(ranks)) + + print(f"TP", all_tp_group_ranks) + + return all_tp_group_ranks, all_pp_group_ranks, all_dp_group_ranks + + # # Build the pipeline model-parallel groups and embedding groups + # # (first and last rank in each pipeline model-parallel group). 
+ # for i in range(num_pipeline_model_parallel_groups): + # ranks = range(i, world_size, + # num_pipeline_model_parallel_groups) + # print(f"EMB{i}", list(ranks)) + + +def reshape(src, tgt): + """ + reshape([tp_size_src, pp_size_src, dp_size_src], + [tp_size_tgt, pp_size_tgt, dp_size_tgt]) + """ + + print(f"\n\n*** Reshaping: {src} => {tgt}") + + tp_size_src, pp_size_src, dp_size_src = src + tp_size_tgt, pp_size_tgt, dp_size_tgt = tgt + + tp_ranks1, pp_ranks1, dp_ranks1 = get_mpu_ranks(tp_size=tp_size_src, pp_size=pp_size_src, dp_size=dp_size_src) + tp_ranks2, pp_ranks2, dp_ranks2 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_src, dp_size=dp_size_src) + tp_ranks3, pp_ranks3, dp_ranks3 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_src) + + # handle tp contraction first + print("\n*** TP contraction:") + + for i, r in enumerate(tp_ranks1): + print(f'{tp_ranks1[i]} => {tp_ranks2[i]}') + + # handle pp contraction next + + print("\n*** PP contraction:") + + for i, r in enumerate(pp_ranks1): + print(f'{pp_ranks2[i]} => {pp_ranks3[i]}') + + +# easy +#reshape([2,2,1],[1,1,1]) + +# probably need more logic to suggest how to pack +#reshape([4,4,1],[2,2,1]) + +#reshape([2,4,2], [8,32,1]) + +# get_mpu_ranks(2,2,2) +# get_mpu_ranks(4,2,1) +# get_mpu_ranks(2,4,1) +# get_mpu_ranks(1,1,8) diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py new file mode 100644 index 000000000000..5c3a687967be --- /dev/null +++ b/deepspeed/checkpoint/reshape_utils.py @@ -0,0 +1,87 @@ +import os +import torch +from collections import OrderedDict + + +def basic_folder_validation(dir): + assert os.path.exists(dir), f'{dir} path does not exist' + assert os.path.isdir(dir), f'{dir} is not a folder' + + +def get_files_with_prefix(all_files, prefix): + file_list = [] + for file_path in all_files: + _, fname = os.path.split(file_path) + if fname.startswith(prefix): + file_list.append(file_path) + + return sorted(file_list) + + +def validate_files(file_list): + for file in file_list: + if not os.path.isfile(file): + print(f'Error: {file} is not existent') + + +def get_files(dir): + file_list = [] + for root, _, files in os.walk(dir): + for file in files: + file_list.append(os.path.join(root, file)) + return file_list + + +def partition_data(data_list, num_partitions): + num_elems = len(data_list) + assert num_elems % num_partitions == 0 + partition_size = num_elems // num_partitions + partitions_list = [ + data_list[i:i + partition_size] for i in range(0, + num_elems, + partition_size) + ] + return partitions_list + + +def _key_list_to_string(key_list): + return '.'.join(key_list) + + +def merge_state_dict(dict_a, dict_b, key_list): + merged_dict = type(dict_a)({}) + + for key, value in dict_b.items(): + if key in dict_a.keys(): + merged_dict[key] = merge_state(dict_a[key], dict_b[key], [str(key)]) + else: + merged_dict[key] = value + + return merged_dict + + +def merge_state_list(list_a, list_b, key_list): + if len(list_a) != len(list_b): + print(f'{_key_list_to_string(key_list)}') + raise ValueError( + f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}' + ) + + return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)] + + +def merge_state(state_a, state_b, key_list=[]): + if type(state_a) != type(state_b): + key_list_string = _key_list_to_string(key_list) + print(f'key_list = {key_list_string}') + raise ValueError( + f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}') + + if type(state_a) in 
(dict, OrderedDict): + return merge_state_dict(state_a, state_b, key_list) + elif type(state_a) in (list, tuple): + return type(state_a)(merge_state_list(state_a, state_b, key_list)) + elif torch.is_tensor(state_a): + return torch.cat([state_a, state_b], 0) + else: + return state_a diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py new file mode 100644 index 000000000000..cb25e524a201 --- /dev/null +++ b/deepspeed/checkpoint/utils.py @@ -0,0 +1,29 @@ +import os +from .constants import (MODEL_FILE_PREFIX, + MODEL_FILE_SUFFIX, + OPTIM_FILE_SUFFIX, + ZERO_FILE_PREFIX) + + +def get_model_ckpt_name_for_rank(base_folder, mp_rank_str): + ckpt_name = os.path.join( + base_folder, + MODEL_FILE_PREFIX + mp_rank_str + MODEL_FILE_SUFFIX, + ) + return ckpt_name + + +def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank): + zero_prefix = f'{ZERO_FILE_PREFIX}{dp_rank}' + mp_rank_string = f'_{MODEL_FILE_PREFIX}{mp_rank:02d}' + zero_ckpt_name = os.path.join( + base_folder, + zero_prefix + mp_rank_string + OPTIM_FILE_SUFFIX, + ) + return zero_ckpt_name + + +def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank): + ckpt_file = f'{layer_id}-model_{tp_rank:02d}{MODEL_FILE_SUFFIX}' + ckpt_path = os.path.join(base_folder, ckpt_file) + return ckpt_path diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py new file mode 100644 index 000000000000..01a6ebe9c1d9 --- /dev/null +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -0,0 +1,146 @@ +import torch + +from .constants import (BASE_OPTIMIZER_STATE, + GROUP_PADDINGS, + OPTIMIZER_STATE_DICT, + PARTITION_COUNT, + ZERO_FILE_PREFIX, + BF16_ZERO_FILE_PREFIX) + +from .reshape_utils import (basic_folder_validation, + get_files, + get_files_with_prefix, + merge_state) + +from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) + +GROUP_STATE_KEY = 'state' + + +class ZeROCheckpoint(object): + def __init__(self, dir): + basic_folder_validation(dir) + self.dir = dir + self.file_list = self._get_zero_files(dir) + self.num_files = len(self.file_list) + assert self.num_files > 0, f'No ZeRO files found in {dir}' + + self.src_3d = get_model_3d_descriptor(dir) + self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree, + tp_degree=self.src_3d.tp_degree, + dp_degree=self.src_3d.dp_degree) + self._3d_file_map = self.src_3d.reshape(self.target_3d) + + def get_file_indices_for_rank(self, pp_index, tp_index, dp_index): + assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}' + dp_2d_map = self._3d_file_map[dp_index] + return dp_2d_map.get_data(pp_index, tp_index) + + def get_files_for_rank(self, pp_index, tp_index, dp_index): + file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index) + return [self.file_list[idx] for idx in file_idx_list] + + def get_state_for_rank(self, + pp_index, + tp_index, + dp_index, + keys_to_ignore=[], + strip_tensor_paddings=True): + state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index) + merged_sd = None + for state_file in state_file_list: + sd = torch.load(state_file, map_location=torch.device('cpu')) + for key in keys_to_ignore: + sd.pop(key, None) + + if strip_tensor_paddings: + self._strip_tensor_paddings(sd) + + if merged_sd is None: + merged_sd = sd + else: + merged_sd = merge_state(merged_sd, sd) + + self._update_partition_count(merged_sd) + if strip_tensor_paddings: + self._clear_group_paddings(merged_sd) + + return merged_sd + + def print_3d_index_map(self, 
tag=None): + if tag: + print(f'3D index map: {tag}') + for dp_index, _2d_map in enumerate(self._3d_file_map): + _2d_map.print_data(f'dp = {dp_index}') + + def print_3d_file_map(self, tag=None): + if tag: + print(f'3D file map: {tag}') + for dp_index, _2d_map in enumerate(self._3d_file_map): + for pp_index in _2d_map.pp_degree: + for tp_index in _2d_map.tp_degree: + file_index_list = _2d_map.get_data(pp_index, tp_index) + file_list = [self.file_list[idx] for idx in file_index_list] + print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}') + + def reshape(self, target_3d_desc: model_3d_desc): + self.target_3d = target_3d_desc + self._3d_file_map = self.src_3d.reshape(self.target_3d) + + def _strip_tensor_paddings(self, sd): + param_group_states = self._get_param_group_states(sd) + if param_group_states is None: + return + + group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS) + if group_paddings is None: + return + + for key, group_state in param_group_states.items(): + if group_paddings[key] == 0: + continue + for state_name, state_value in group_state.items(): + if torch.is_tensor(state_value): + raw_length = state_value.numel() - group_paddings[key] + group_state[state_name] = torch.narrow(state_value, + 0, + 0, + raw_length).clone() + + def _clear_group_paddings(self, sd): + group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS) + if group_paddings: + num_groups = len(group_paddings) + sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups + + def _get_optimizer_state(self, sd, state_key): + optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None) + if optimizer_state is None: + return None + + return optimizer_state.get(state_key, None) + + def _get_param_group_states(self, sd): + optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None) + if optimizer_state is None: + return None + + base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None) + if base_optimizer_state is None: + return None + + return base_optimizer_state.get(GROUP_STATE_KEY, None) + + def _update_partition_count(self, sd): + partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT) + if partition_counts: + num_groups = len(partition_counts) + sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree + ] * num_groups + + def _get_zero_files(self, dir): + file_list = get_files(dir) + zero_files = get_files_with_prefix(file_list, ZERO_FILE_PREFIX) + if len(zero_files) > 0: + return zero_files + return get_files_with_prefix(file_list, BF16_ZERO_FILE_PREFIX) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 9dd3bdd4e329..d5ffbac9d1d3 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -1,4 +1,10 @@ +""" +Copyright 2022 The Microsoft DeepSpeed Team +""" + +from typing import OrderedDict import torch +import os from deepspeed import comm as dist from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.ops.op_builder import UtilsBuilder @@ -6,6 +12,7 @@ from packaging import version as pkg_version from deepspeed.git_version_info import version +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import print_rank_0 from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, @@ -20,7 +27,9 @@ BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, - GROUPS_PADDING) + GROUP_PADDINGS, + PARAM_SLICE_MAPPINGS, + FP32_WEIGHT_KEY) import types @@ -53,6 +62,12 @@ def get_optim_state_fragment(self, key): else: raise 
ValueError(f'{key} not found in optimizer state fragment') + def get_hp_fragment_address(self): + return self.hp_fragment_address + + def get_optim_state_keys(self): + return list(self.optim_fragment.keys()) + def get_full_hp_param(self, optim_state_key=None): reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() @@ -72,9 +87,103 @@ def get_full_hp_param(self, optim_state_key=None): return reduce_buffer.reshape_as(self) +def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): + hp_mapping = self._hp_mapping + optim_state_keys = hp_mapping.get_optim_state_keys() + hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys + checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} + + for file in checkpoint_files.values(): + assert os.path.isfile(file), f'{file} is not a valid file' + + for key in hp_keys: + ckpt_file = checkpoint_files[key] + ckpt_dict = torch.load(ckpt_file) + full_hp_param = ckpt_dict['param'] + + # need to deal with slices that were averaged. + # the opposite of averaging here becomes an exact copy of the first slice + # I thought of 2 ways: + # implementation a. find a way for a client to pass a dict with patterns + # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): + # tp_rank = 0 + # tp_world_size = 1 + # the other approach is to assume that the saved data is correct and if full_hp_param.shape == + # self.shape that means we automatically copy? + # implementation b. + # this version requires no additional data passed from the client + # if the shapes already match it must be slices that were averaged - so we just hack around those + if full_hp_param.shape == self.shape: + tp_rank = 0 + tp_world_size = 1 + + # special case for word_embeddings weights which get padded differently depending on TP degree. 
+ # the converter to universal currently strips the original padding completely so the saved + # weight is padding-free and we just need to add new padding depending on the target TP + # degree + vocab_divisibility_padding_tensor = ckpt_dict.get( + 'vocab_divisibility_padding_tensor', + None) + if vocab_divisibility_padding_tensor is not None: + # In the absence of data passed from the user wrt new padded vocab specific to tp degree + # we can again derive that data by reverse engineering the target shapes like so: + padded_target_vocab_size = self.shape[0] * tp_world_size + if padded_target_vocab_size > full_hp_param.shape[0]: + # Need to expand + padding_tensor = vocab_divisibility_padding_tensor.expand( + padded_target_vocab_size - full_hp_param.shape[0]) + # Implement the following concat in efficient way using pad + #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) + full_hp_param = torch.nn.functional.pad(full_hp_param, + (0, + 0, + 0, + padding_tensor.shape[0]), + "constant", + 0) + full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor + else: + # Need to shrink or keep the same + full_hp_param = full_hp_param[:padded_target_vocab_size, :] + + full_param_numel = full_hp_param.numel() + tp_slice_numel = self.numel() + # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: + # print_rank_0(f'{full_hp_param[:10]=}', force=True) + + + assert full_param_numel == tp_world_size * tp_slice_numel, \ + f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' + dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment( + key) + + # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") + # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") + + # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse + chunk_dim = ckpt_dict.get('cat_dim', 0) + + # this performs the opposite of cat when merging TP slices + tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] + tp_hp_slice = tp_hp_slice.flatten() + + lp_frag_address = hp_mapping.lp_fragment_address + tp_hp_fragment = tp_hp_slice.narrow(0, + lp_frag_address.start, + lp_frag_address.numel) + assert dst_tensor.numel() == lp_frag_address.numel, \ + f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}' + + # print(f"{key} SHAPE: {tp_hp_slice.shape=}") + # print(f"{key} SHAPE: {dst_tensor.shape=}") + # print(f"{key} SHAPE: {tp_hp_fragment.shape=}") + dst_tensor.data.copy_(tp_hp_fragment.data) + + class BF16_Optimizer(ZeROOptimizer): def __init__(self, init_optimizer, + param_names, mpu=None, clip_grad=0.0, norm_type=2, @@ -85,6 +194,7 @@ def __init__(self, see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers self.optimizer = init_optimizer + self.param_names = param_names self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim) self.clip_grad = clip_grad @@ -120,7 +230,7 @@ def __init__(self, self.fp32_groups_has_gradients = [] self.step_count = 0 - self.groups_padding = [] + self.group_paddings = [] if self.using_real_optimizer: self._setup_for_real_optimizer() @@ -205,7 +315,7 @@ def _setup_for_real_optimizer(self): else: padding = 0 - self.groups_padding.append(padding) + self.group_paddings.append(padding) # update optimizer param groups to reference fp32 params partition param_group['params'] = 
[self.fp32_groups_flat_partition[i]] @@ -218,12 +328,25 @@ def _setup_for_real_optimizer(self): # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._param_slice_mappings = self._create_param_mapping() + + def _create_param_mapping(self): + param_mapping = [] + for i, _ in enumerate(self.optimizer.param_groups): + param_mapping_per_group = OrderedDict() + for lp in self.bf16_groups[i]: + if lp._hp_mapping is not None: + lp_name = self.param_names[lp] + param_mapping_per_group[ + lp_name] = lp._hp_mapping.get_hp_fragment_address() + param_mapping.append(param_mapping_per_group) + + return param_mapping def _link_all_hp_params(self): dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, param_group in enumerate(self.optimizer.param_groups): # Link bf16 and fp32 params in partition - # TODO: Make this configurable partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_size = self.bf16_groups_flat[i].numel() // dp_world_size self._link_hp_params(self.bf16_groups[i], @@ -244,6 +367,9 @@ def _init_lp_to_hp_mapping(self, lp_param._hp_mapping = None lp_param._dp_group = dp_group lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) + lp_param.load_hp_checkpoint_state = types.MethodType( + load_hp_checkpoint_state, + lp_param) # lp_param overlaps with partition if both are true # 1) current_offset < partition_end, # 2) current_offset + lp_param.numel() >= partition_start @@ -363,11 +489,6 @@ def step(self, closure=None): self.update_lp_params() - all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups, - dp_process_group=self.real_dp_process_group, - start_alignment_factor=self.nccl_start_alignment_factor, - allgather_bucket_size=self.allgather_bucket_size) - self.clear_hp_grads() self.step_count += 1 @@ -434,6 +555,14 @@ def update_lp_params(self): for i, (bf16_partitions, fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) bf16_partitions[partition_id].data.copy_(fp32_partition.data) + # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) + # if i == 0: + # print_rank_0(f'{fp32_partition[:10]=}', force=True) + + all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) def clear_hp_grads(self): for flat_gradients in self.fp32_groups_gradients_flat: @@ -452,9 +581,10 @@ def state_dict(self): state_dict[CLIP_GRAD] = self.clip_grad state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition - state_dict[GROUPS_PADDING] = self.groups_padding + state_dict[GROUP_PADDINGS] = self.group_paddings state_dict[PARTITION_COUNT] = self.partition_count state_dict[DS_VERSION] = version + state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings return state_dict @@ -470,8 +600,23 @@ def refresh_fp32_params(self): def load_state_dict(self, state_dict_list, + checkpoint_folder, load_optimizer_states=True, load_from_fp32_weights=False): + if checkpoint_folder: + self._load_universal_checkpoint(checkpoint_folder, + load_optimizer_states, + load_from_fp32_weights) + else: + self._load_legacy_checkpoint(state_dict_list, + load_optimizer_states, + load_from_fp32_weights) + + def _load_legacy_checkpoint(self, 
+ state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + dp_rank = dist.get_rank(group=self.dp_process_group) current_rank_sd = state_dict_list[dp_rank] @@ -492,11 +637,32 @@ def load_state_dict(self, if load_optimizer_states: self._link_all_hp_params() + def _load_universal_checkpoint(self, + checkpoint_folder, + load_optimizer_states, + load_from_fp32_weights): + self._load_hp_checkpoint_state(checkpoint_folder) + @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups + def _load_hp_checkpoint_state(self, checkpoint_dir): + checkpoint_dir = os.path.join(checkpoint_dir, "zero") + tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) + tp_world_size = self.mpu.get_slice_parallel_world_size() + + for i, _ in enumerate(self.optimizer.param_groups): + for lp in self.bf16_groups[i]: + if lp._hp_mapping is not None: + #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") + lp.load_hp_checkpoint_state( + os.path.join(checkpoint_dir, + self.param_names[lp]), + tp_rank, + tp_world_size) + def _get_padded_tensor(src_tensor, size): if src_tensor.numel() >= size: diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 35d50d00f36b..9e68ac60ec02 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -867,6 +867,9 @@ def _initialize_params(self, param_dict): self.checkpoint_tag_validation_enabled = (validation_mode != ValidationMode.IGNORE) self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL + self.load_universal_checkpoint = checkpoint_params.get( + LOAD_UNIVERSAL_CHECKPOINT, + LOAD_UNIVERSAL_CHECKPOINT_DEFAULT) self.aio_config = get_aio_config(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 250240089f89..2ef10161f042 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -360,7 +360,10 @@ class ValidationMode: ######################################### # Checkpoint config params ######################################### -# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]} +# "checkpoint": { +# tag_validation=["Ignore"|"Warn"|"Fail"] +# load_universal=false +# } CHECKPOINT = "checkpoint" CHECKPOINT_TAG_VALIDATION = "tag_validation" CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN @@ -370,6 +373,9 @@ class ValidationMode: ValidationMode.FAIL ] +LOAD_UNIVERSAL_CHECKPOINT = "load_universal" +LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False + ######################################### # Drop the last incomplete Batch # ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 51ed1b8e5884..07638b33033f 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -18,6 +18,7 @@ from torch.optim.lr_scheduler import _LRScheduler from typing import Callable, Dict, Optional, Union, Iterable +from deepspeed.checkpoint.utils import get_zero_ckpt_name_for_rank import deepspeed @@ -724,6 +725,9 @@ def loss_scale(self): def gradient_accumulation_steps(self): return self._config.gradient_accumulation_steps + def load_universal_checkpoint(self): + return self._config.load_universal_checkpoint + @property def communication_data_type(self): res = self._config.communication_data_type @@ -1329,6 +1333,7 @@ def _configure_bf16_optimizer(self, optimizer): timers = self.timers if self.wall_clock_breakdown() else None optimizer = BF16_Optimizer( optimizer, + self.param_names, mpu=self.mpu, 
clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), @@ -2499,16 +2504,23 @@ def load_checkpoint(self, """ if tag is None: - latest_path = os.path.join(load_dir, "latest") + latest_tag = "latest_universal" if self.load_universal_checkpoint( + ) else "latest" + latest_path = os.path.join(load_dir, latest_tag) if os.path.isfile(latest_path): with open(latest_path, "r") as fd: tag = fd.read().strip() else: - logger.warning( - f"Unable to find latest file at {latest_path}, if trying to load latest " - "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." - ) - return None, None + if self.load_universal_checkpoint(): + raise ValueError( + f'Invalid for universal checkpoint: {latest_path} does not exist' + ) + else: + logger.warning( + f"Unable to find latest file at {latest_path}, if trying to load latest " + "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." + ) + return None, None if self.zero_optimization_partition_weights(): # Prepare for checkpoint load by ensuring all parameters are partitioned @@ -2576,10 +2588,10 @@ def _load_checkpoint(self, model=self.module, mpu=self.mpu, num_experts=self.num_experts) - - self.load_module_state_dict(state_dict=checkpoint['module'], - strict=load_module_strict, - custom_load_fn=custom_load_fn) + if not self.load_universal_checkpoint(): + self.load_module_state_dict(state_dict=checkpoint['module'], + strict=load_module_strict, + custom_load_fn=custom_load_fn) self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] @@ -2677,25 +2689,35 @@ def get_sparse_tensor_module_names(original_set, return load_path, client_state def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): - zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) - if zero_sd_list is None: - return False - - if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size: - raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ - f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ - f"current world size is {self.dp_world_size}. Automatic adjustment " \ - "of ZeRO's optimizer state partitioning with a new world size is not " \ - "currently supported.") + if self.load_universal_checkpoint(): + zero_sd_list = None + checkpoint_folder = f'{os.path.join(load_dir, tag)}' + else: + if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size: + raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ + f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ + f"current world size is {self.dp_world_size}. 
Automatic adjustment " \ + "of ZeRO's optimizer state partitioning with a new world size is not " \ + "currently supported.") + checkpoint_folder = None + zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) + if zero_sd_list is None: + return False self.optimizer.load_state_dict( state_dict_list=zero_sd_list, load_optimizer_states=load_optimizer_states, load_from_fp32_weights=self.zero_load_from_fp32_weights(), - ) - logger.info( - f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}" - ) + checkpoint_folder=checkpoint_folder) + + if self.load_universal_checkpoint(): + logger.info( + f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}' + ) + else: + logger.info( + f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}" + ) return True def _get_mp_rank_zero_checkpoint_names(self, diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e4afdfdbfb0c..6f7185413d88 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2357,7 +2357,8 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): def load_state_dict(self, state_dict_list, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + checkpoint_folder=None): r"""Loading a ZeRO checkpoint Arguments: state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index c019bd0e3647..c36c17dc02e4 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -16,15 +16,17 @@ align_dense_tensors, all_gather_dp_groups) -from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS +from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_OPTIMIZER_STATES from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.utils import logger from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version + from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.checkpoint.constants import (DS_VERSION, + GROUP_PADDINGS, PARTITION_COUNT, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, @@ -278,15 +280,6 @@ def __init__(self, ] self.bit16_groups.append(trainable_parameters) - # Record padding required to align group to world size - if partition_id == dist.get_world_size( - group=self.real_dp_process_group[i]) - 1: - padding = get_alignment_padding(self.bit16_groups[i], - self.partition_count[i]) - else: - padding = 0 - self.groups_padding.append(padding) - # not sure why apex was cloning the weights before flattening # removing cloning here @@ -321,6 +314,15 @@ def __init__(self, see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) + # Record padding required for alignment + if partition_id == dist.get_world_size( + group=self.real_dp_process_group[i]) - 1: + padding = self.bit16_groups_flat[i].numel() - sum( + [t.numel() for t in self.round_robin_bit16_groups[i]]) + else: + padding = 0 + self.groups_padding.append(padding) + if dist.get_rank(group=self.real_dp_process_group[i]) == 0: see_memory_usage( f"After Flattening and after emptying param group {i} cache", @@ -2045,7 +2047,9 @@ def state_dict(self): 
self.single_partition_of_fp32_groups) state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding - state_dict[ZERO_STAGE] = ZERO_OPTIMIZATION_GRADIENTS + state_dict[ + ZERO_STAGE] = ZERO_OPTIMIZATION_GRADIENTS if self.partition_gradients else ZERO_OPTIMIZATION_OPTIMIZER_STATES + state_dict[GROUP_PADDINGS] = self.groups_padding state_dict[PARTITION_COUNT] = self.partition_count state_dict[DS_VERSION] = version @@ -2158,7 +2162,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict): def load_state_dict(self, state_dict_list, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + checkpoint_folder=None): r"""Loading ZeRO checkpoint Arguments: @@ -2208,6 +2213,16 @@ def load_state_dict(self, ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict) + # padding is always at the last rank/partition + # if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank + # scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus + # scenario-2 (expand): saving w. 2 gpus -> loading w. 4 gpus + # if load_optimizer_states: + # if new_dp_size: + # self.strip_padding() + # self.add_padding_w_new_dp_size() + # self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) + if load_optimizer_states: if ckpt_is_rigid: # loading rigid ckpt into either rigid or elastic exec diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index ddac8a3dcd02..dd93e006081f 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -16,7 +16,7 @@ from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 -from .util import required_torch_version +from .util import required_minimum_torch_version, required_torch_version import itertools import argparse @@ -88,18 +88,25 @@ def compare_model_states(saved_model, assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}' +def _compare_state_dicts(state0, state1, expected_mismatch_keys=[]): + for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()): + assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' + if k0 in expected_mismatch_keys: + continue + if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): + assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' + assert torch.equal(s0.to('cpu'), s1.to('cpu')) + else: + assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}' + + def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True): saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer for state0, state1 in zip(saved_optimizer.state.values(), loaded_optimizer.state.values()): - for s0, s1 in zip(state0.values(), state1.values()): - if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): - assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' - assert torch.equal(s0, s1) - else: - assert s0 == s1 + _compare_state_dicts(state0, state1) def compare_lr_scheduler_states(saved_model, loaded_model): @@ -1178,6 +1185,11 @@ def test_checkpoint_zero_elastic(tmpdir, elastic_save, elastic_load, load_optim) @distributed_test(world_size=[2]) def _go(): + # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to + 
# false positive mismatches in checkpoint state comparisons. + # Newer torch versions store tensor ids as 0, 1, 2, ... + expected_mismatch_keys = [] if required_minimum_torch_version(1, + 4) else ['params'] models = [SimpleModel(hidden_dim) for _ in range(2)] model, _, _, _ = deepspeed.initialize(config=ds_config, model=models[0], @@ -1190,6 +1202,10 @@ def _go(): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + if load_optim: + torch.save(model.optimizer.optimizer.state_dict(), + os.path.join(tmpdir, + 'opt-state-dict')) model.save_checkpoint(tmpdir) ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load @@ -1197,6 +1213,15 @@ def _go(): model=models[1], model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) + + if load_optim: + saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + curr_sd = model.optimizer.optimizer.state_dict() + for curr_param_group, saved_param_group in zip(curr_sd['param_groups'], saved_sd['param_groups']): + _compare_state_dicts(curr_param_group, + saved_param_group, + expected_mismatch_keys) + data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, @@ -1252,6 +1277,11 @@ def _go2(models): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + if load_optim: + torch.save(model.optimizer.optimizer.state_dict(), + os.path.join(tmpdir, + 'opt-state-dict')) model.save_checkpoint(tmpdir) _go2(models) @@ -1260,8 +1290,8 @@ def _go2(models): def _go1(models): ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load model, _, _, _ = deepspeed.initialize(config=ds_config, - model=models[1], - model_parameters=models[1].parameters()) + model=models[1], + model_parameters=models[1].parameters()) if load_optim: with pytest.raises(deepspeed.runtime.zero.utils.ZeRORuntimeException): model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) diff --git a/tests/unit/test_reshape_checkpoint.py b/tests/unit/test_reshape_checkpoint.py new file mode 100644 index 000000000000..317f3bb1661f --- /dev/null +++ b/tests/unit/test_reshape_checkpoint.py @@ -0,0 +1,58 @@ +import pytest +import deepspeed + +from deepspeed.checkpoint import model_3d_desc + + +def _do_reshape(src_3d, tgt_3d): + assert src_3d.can_reshape(tgt_3d) + new_3d_map = src_3d.reshape(tgt_3d) + + assert len(new_3d_map) == tgt_3d.dp_degree + for new_2d_map in new_3d_map: + assert new_2d_map.pp_degree == tgt_3d.pp_degree + assert new_2d_map.tp_degree == tgt_3d.tp_degree + + return new_3d_map + + +# Specify 3d shape as pp/tp/dp +def test_reshape_222_to_111(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=1, tp_degree=1, dp_degree=1) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5, 2, 6, 3, 7] + + +def test_reshape_222_to_121(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=1) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 2, 6] + assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5, 3, 7] + + +def test_reshape_222_to_122(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=2) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4] + assert new_3d_map[0].get_data(pp_index=0, 
tp_index=1) == [1, 5] + assert new_3d_map[1].get_data(pp_index=0, tp_index=0) == [2, 6] + assert new_3d_map[1].get_data(pp_index=0, tp_index=1) == [3, 7] + + +def test_reshape_222_to_211(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=2, tp_degree=1, dp_degree=1) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5] + assert new_3d_map[0].get_data(pp_index=1, tp_index=0) == [2, 6, 3, 7] diff --git a/tests/unit/util.py b/tests/unit/util.py index 79a459da3c14..0aa72a2ad032 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -25,3 +25,23 @@ def bf16_required_version_check(): return True else: return False + + +def required_minimum_torch_version(major_version, minor_version): + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR < major_version: + return False + + return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version + + +def required_maximum_torch_version(major_version, minor_version): + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR > major_version: + return False + + return TORCH_MAJOR < major_version or TORCH_MINOR <= minor_version
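The runtime side of this PR is driven by the new "checkpoint" config knob introduced above. A minimal sketch of how a client might enable it and load a universal checkpoint follows; the model, paths, and the rest of the config are placeholders, and the universal checkpoint folder (pointed to by a latest_universal tag file) is assumed to have been produced by a separate conversion step.

# Illustrative sketch only -- not part of the patch; model, paths and the
# surrounding config values are hypothetical.
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "checkpoint": {
        "load_universal": True   # new flag (LOAD_UNIVERSAL_CHECKPOINT), default False
    },
    # ... remaining optimizer/precision settings as in the original training run ...
}

model = torch.nn.Linear(8, 8)    # stand-in for the real model
engine, _, _, _ = deepspeed.initialize(config=ds_config,
                                       model=model,
                                       model_parameters=model.parameters())

# With no explicit tag, the engine now resolves the tag from the 'latest_universal'
# file in load_dir and passes the resolved folder to the optimizer, which reloads
# each parameter's fp32 weight and optimizer state fragments from <tag>/zero/<param_name>/*.pt.
engine.load_checkpoint('/path/to/checkpoints', load_optimizer_states=True)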