Checkpoint reshaping #1953
Merged
Commits (71)

24fe700  unit test, remove exception, add notes  (jeffra)
70a68d0  Merge branch 'master' of github.com:microsoft/DeepSpeed into elastic-…  (tjruwase)
aafa4e5  Move param_shapes to model files  (tjruwase)
162c19b  Remove hard-coded constants  (tjruwase)
84c5d17  Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…  (tjruwase)
59e86dd  Merge branch 'master' into olruwase/relocate_param_shapes  (tjruwase)
680e620  Conditioned to zero optimizer  (tjruwase)
8bf3c4e  Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…  (tjruwase)
f1b5d16  Add zero checkpoint merging  (tjruwase)
58d3495  Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…  (tjruwase)
145638d  Merge branch 'master' into olruwase/relocate_param_shapes  (jeffra)
fd8c3e6  Print checkpoint version  (tjruwase)
d85a6df  Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…  (tjruwase)
c642600  Merge with relocate_param_shapes  (tjruwase)
c8689fd  Reshape zero_* ckpt files  (tjruwase)
4a86c1a  Merge zero* files contraction  (tjruwase)
f5db8df  Utils for 3D contraction reshaping  (tjruwase)
d5c6843  Rebase  (tjruwase)
e617920  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
ef8a4a7  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
c12a4e7  Remove bogus import  (tjruwase)
0b2c33b  Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…  (tjruwase)
86efe30  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
1031b32  Support bf16_zero ckpts  (tjruwase)
6f29465  Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…  (tjruwase)
8f23728  Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase…  (tjruwase)
fd1a377  Add param slice mappings  (tjruwase)
3d4a27b  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
10083db  Load universal checkpoints  (tjruwase)
567454a  Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…  (tjruwase)
22c7550  Per group mappings from Stas  (tjruwase)
5df4135  Hack to load bf16 zero files  (tjruwase)
ae2825f  Param attributes  (tjruwase)
d11a8dc  WIP  (tjruwase)
7948c45  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
691b29d  Fix api bug  (tjruwase)
a05f953  Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…  (tjruwase)
c0a42d3  Update lp with local/remote hp  (tjruwase)
b4ca455  Disable vocab padding handling  (tjruwase)
b8b54c8  Update z2 checkpoint  (tjruwase)
be86df9  Remove debug prints  (tjruwase)
c87543b  Remove debug prints; Rebase unit test  (tjruwase)
c18ff2d  Add reshape assert  (tjruwase)
4ea36b7  Padding  (tjruwase)
0371581  Typo  (tjruwase)
a74abc1  Catch nonexistent checkpoint path  (tjruwase)
2b707f2  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
529dbae  Cleanup  (tjruwase)
e126d2e  Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…  (tjruwase)
9e2766f  Restore checkpoint state comparisons  (tjruwase)
5c90ef1  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
726982b  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
add1d0c  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (jeffra)
5fca3db  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
901b1e6  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
30896de  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
93934f6  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
ecb3dc8  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
6c7d947  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
cd8dea7  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
4217be2  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
206e630  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
14980ad  Add torch version guards  (tjruwase)
f314581  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
868c463  More precise avoidance of false positives.  (tjruwase)
e22487a  Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…  (tjruwase)
e0da15f  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
623430e  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
2556578  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
bf57d81  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
e4a5a46  Merge branch 'master' into olruwase/elastic-ckpt-refresh  (tjruwase)
New file (@@ -0,0 +1,13 @@):

from .reshape_meg_2d import reshape_meg_2d_parallel

from .deepspeed_checkpoint import DeepSpeedCheckpoint

from .utils import (get_layer_ckpt_name_for_rank,
                    get_model_ckpt_name_for_rank,
                    get_zero_ckpt_name_for_rank)

from .reshape_utils import (merge_state)

from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

from .zero_checkpoint import ZeROCheckpoint
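
For orientation, here is a minimal sketch of how one of the exports above is used. It is not part of the diff: the deepspeed.checkpoint import path is an assumption, and the degree values are illustrative. The keyword names match the call made inside DeepSpeedCheckpoint below.

```python
# Hedged sketch, not part of this PR. Assumes the package above is importable
# as deepspeed.checkpoint.
from deepspeed.checkpoint import reshape_meg_2d_parallel

# Map an original (pp=4, tp=2) grid of mp_rank_* files onto a smaller
# (pp=2, tp=1) grid; DeepSpeedCheckpoint builds the same map internally.
new_2d_map = reshape_meg_2d_parallel(old_pp_degree=4,
                                     old_tp_degree=2,
                                     new_pp_degree=2,
                                     new_tp_degree=1)

# Which original 2D ranks feed the new (pp=0, tp=0) coordinate.
print(new_2d_map.get_data(pp_index=0, tp_index=0))
```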
New file (@@ -0,0 +1,316 @@):

import os
from typing import Dict
import torch

from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation,
                            merge_state,
                            partition_data,
                            get_files,
                            get_files_with_prefix)

from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)

from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
from .zero_checkpoint import ZeROCheckpoint
from .constants import *

EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'

SEQUENTIAL_LAYERS = [
    'input_layernorm.weight',
    'input_layernorm.bias',
    'self_attention.dense.bias',
    'post_attention_layernorm.weight',
    'post_attention_layernorm.bias',
    'mlp.dense_4h_to_h.bias',
    'position_embeddings.weight'
]

LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}


class DeepSpeedCheckpoint(object):
    def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
        self.dir = dir
        self._validate_folder(dir)

        self.file_list = get_files(dir)
        self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)

        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)
        self.original_tp_degree = len(
            get_files_with_prefix(self.layer_files,
                                  f'{LAYER_FILE_PREFIX}01'))
        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        self.original_dp_degree = max(
            1,
            len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree))

        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree

        self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree
        self.world_size = self.tp_degree * self.pp_degree * self.dp_degree

        self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree,
                                              self.original_tp_degree)
        self.old_2d_map.simple_init()
        self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree,
                                                  old_tp_degree=self.original_tp_degree,
                                                  new_pp_degree=self.pp_degree,
                                                  new_tp_degree=self.tp_degree)

        self.zero_checkpoint = ZeROCheckpoint(dir)
        if self.is_change_pp_degree() or self.is_change_tp_degree(
        ) or self.is_change_dp_degree():
            self.zero_checkpoint.reshape(
                model_3d_desc(self.pp_degree,
                              self.tp_degree,
                              self.dp_degree))

        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(
            FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def is_change_tp_degree(self):
        return self.tp_degree != self.original_tp_degree

    def is_change_pp_degree(self):
        return self.pp_degree != self.original_pp_degree

    def is_change_dp_degree(self):
        return self.dp_degree != self.original_dp_degree

    def show_2d_mapping(self):
        print(f'reshaped 2d map ---- begin')

        for i in range(self.pp_degree):
            for j in range(self.tp_degree):
                file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j)
                print(f'[{i}, {j}] = {file_list}')

        print(f'reshaped 2d map ---- end')

    def show_tp_embedding_map(self):
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_tranformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files')

    def _build_global_state(self):
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict:
        return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
                                                       tp_index=tp_index,
                                                       dp_index=dp_index,
                                                       keys_to_ignore=[PARAM_SHAPES])

    def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
        return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index,
                                                       tp_index=tp_index,
                                                       dp_index=dp_index)

    def get_embedding_layer_id(self):
        return self.layer_keys[EMBEDDING_LAYER_INDEX]

    def get_final_norm_layer_id(self):
        return self.layer_keys[FINAL_LAYER_NORM_INDEX]

    def get_iteration(self):
        if not ITERATION_KEY in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map.keys()
        sd_list = [
            torch.load(fname,
                       map_location=torch.device('cpu'))
            for fname in self.tp_to_embedding_map[tp_index]
        ]
        sd = self._merge_state_dicts(sd_list)
        return sd

    def get_embedding_files(self, tp_index: int) -> list:
        assert tp_index in self.tp_to_embedding_map.keys()
        return self.tp_to_embedding_map[tp_index]

    def _get_checkpoint_value(self, key):
        if not key in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[key] = sd.get(key, None)

        return self.global_state[key]

    def get_args(self):
        return self._get_checkpoint_value(ARGS_KEY)

    def get_checkpoint_info(self):
        return self._get_checkpoint_value(CHECKPOINT_INFO_KEY)

    def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
        sd_list = [
            torch.load(fname,
                       map_location=torch.device('cpu')) for fname in fname_list
        ]

        merged_sd = None
        for sd in sd_list:
            if merged_sd is None:
                merged_sd = sd
            else:
                merged_sd = merge_state(merged_sd, sd)

        return merged_sd

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [
                torch.load(fname,
                           map_location=torch.device('cpu')) for fname in fname_list
            ]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)
        return t_list

    def get_pp_transformer_map(self, pp_index: int) -> list:
        assert pp_index < self.pp_degree
        return self.pp_to_transformer_map[pp_index]

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map.keys()
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0],
                        map_location=torch.device('cpu'))
        return sd

    def get_final_norm_files(self, tp_index: int) -> list:
        assert tp_index in self.tp_to_final_norm_map.keys()
        return self.tp_to_final_norm_map[tp_index]

    def _build_tp_other_layer_map(self, layer_index: int):
        assert layer_index < len(self.layer_files)
        layer_files = get_files_with_prefix(self.layer_files,
                                            self.layer_keys[layer_index])
        layer_file_partitions = partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map

    def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        file_indices = self.new_2d_map.get_data(pp_index=pp_index, tp_index=tp_index)
        return [self.mp_rank_files[i] for i in file_indices]

    def _build_pp_transformer_map(self):
        data_map = {}
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        data_map = {
            i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
            for i in range(0,
                           self.pp_degree)
        }
        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        # XXX: this is not guaranteed
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        if layers_per_pp == 0:
            layers_per_pp = 1
        #print(f"{transformer_layer_keys} {layers_per_pp}")
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                map_key = (tp_index, pp_index)
                if not map_key in file_map.keys():
                    file_map[map_key] = []
                file_map[map_key].append(layer_file_partitions[tp_index])

        return file_map

    def _sanity_check(self):
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        assert len(self.layer_keys) > 2
        # XXX: fix me - isn't always the case
        # only true with --pp-partition-method 'type:transformer|embedding' \
        # assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def validate_files(self):
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} is not existent')

    def _get_layer_keys(self):
        key_set = set()
        key_len = len(LAYER_FILE_PREFIX) + 2
        for file_path in self.layer_files:
            _, fname = os.path.split(file_path)
            key_set.add(fname[:key_len])
        return sorted(list(key_set))

    def _merge_state_dicts(self, sd_list):
        merged_sd = {}
        for key in sd_list[0].keys():
            if not key in SEQUENTIAL_LAYERS:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                merged_sd[key] = sd_list[0][key]

        return merged_sd

    def _validate_folder(self, dir):
        basic_folder_validation(dir)

        file_list = get_files(dir)

        for file_prefix in [
                MODEL_FILE_PREFIX,
                LAYER_FILE_PREFIX,
                f'{LAYER_FILE_PREFIX}01'
        ]:
            ckpt_files = get_files_with_prefix(file_list, file_prefix)
            assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'
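
A short usage sketch of the class above, along the lines of what a checkpoint conversion script would do when contracting parallelism. It is not part of the diff: the import path and checkpoint folder are assumptions, and only methods defined above are called.

```python
# Hedged sketch, not part of this PR. Assumes the class is importable as
# deepspeed.checkpoint.DeepSpeedCheckpoint and that /data/global_step1000 is a
# Megatron-DeepSpeed checkpoint folder (both are placeholders).
from deepspeed.checkpoint import DeepSpeedCheckpoint

# Ask for a contracted target layout: merge everything down to tp=1, pp=1.
ds_ckpt = DeepSpeedCheckpoint('/data/global_step1000', tp_degree=1, pp_degree=1)
print(ds_ckpt.get_iteration())
ds_ckpt.show_2d_mapping()  # which original mp_rank_* files feed each new (pp, tp) slot

for tp_index in range(ds_ckpt.tp_degree):
    # Embedding shards assigned to this target tp rank are merged by
    # _merge_state_dicts: concatenated on dim 0 by default, on dim 1 for the
    # weights in LAYER_CONCAT_DIM, and taken from the first shard for names in
    # SEQUENTIAL_LAYERS.
    embedding_sd = ds_ckpt.get_embedding_state(tp_index)
    # The final layer norm is loaded from the first file mapped to this rank.
    final_norm_sd = ds_ckpt.get_final_norm_state(tp_index)
    print(tp_index, sorted(embedding_sd.keys()), sorted(final_norm_sd.keys()))
    for pp_index in range(ds_ckpt.pp_degree):
        # get_transformer_state returns one merged state dict per transformer
        # layer assigned to this (tp, pp) slot of the target grid.
        for layer_sd in ds_ckpt.get_transformer_state(tp_index, pp_index):
            for name, tensor in layer_sd.items():
                print(tp_index, pp_index, name, tuple(tensor.shape))
```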