Merged

Commits (71)
24fe700
unit test, remove exception, add notes
jeffra Jan 21, 2022
70a68d0
Merge branch 'master' of github.com:microsoft/DeepSpeed into elastic-…
tjruwase Jan 27, 2022
aafa4e5
Move param_shapes to model files
tjruwase Jan 29, 2022
162c19b
Remove hard-coded constants
tjruwase Jan 29, 2022
84c5d17
Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…
tjruwase Jan 29, 2022
59e86dd
Merge branch 'master' into olruwase/relocate_param_shapes
tjruwase Jan 29, 2022
680e620
Conditioned to zero optimizer
tjruwase Jan 29, 2022
8bf3c4e
Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…
tjruwase Jan 29, 2022
f1b5d16
Add zero checkpoint merging
tjruwase Jan 29, 2022
58d3495
Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…
tjruwase Jan 29, 2022
145638d
Merge branch 'master' into olruwase/relocate_param_shapes
jeffra Jan 31, 2022
fd8c3e6
Print checkpoint version
tjruwase Jan 31, 2022
d85a6df
Merge branch 'olruwase/relocate_param_shapes' of github.com:microsoft…
tjruwase Jan 31, 2022
c642600
Merge with relocate_param_shapes
tjruwase Jan 31, 2022
c8689fd
Reshape zero_* ckpt files
tjruwase Feb 7, 2022
4a86c1a
Merge zero* files contraction
tjruwase Feb 8, 2022
f5db8df
Utils for 3D contraction reshaping
tjruwase Feb 23, 2022
d5c6843
Rebase
tjruwase Apr 19, 2022
e617920
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase May 14, 2022
ef8a4a7
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase May 16, 2022
c12a4e7
Remove bogus import
tjruwase May 16, 2022
0b2c33b
Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…
tjruwase May 16, 2022
86efe30
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase May 18, 2022
1031b32
Support bf16_zero ckpts
tjruwase May 20, 2022
6f29465
Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…
tjruwase May 20, 2022
8f23728
Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase…
tjruwase May 20, 2022
fd1a377
Add param slice mappings
tjruwase May 20, 2022
3d4a27b
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase May 20, 2022
10083db
Load universal checkpoints
tjruwase May 24, 2022
567454a
Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…
tjruwase May 24, 2022
22c7550
Per group mappings from Stas
tjruwase May 25, 2022
5df4135
Hack to load bf16 zero files
tjruwase May 26, 2022
ae2825f
Param attributes
tjruwase May 31, 2022
d11a8dc
WIP
tjruwase Jun 1, 2022
7948c45
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 2, 2022
691b29d
Fix api bug
tjruwase Jun 2, 2022
a05f953
Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…
tjruwase Jun 2, 2022
c0a42d3
Update lp with local/remote hp
tjruwase Jun 2, 2022
b4ca455
Disable vocab padding handling
tjruwase Jun 4, 2022
b8b54c8
Update z2 checkpoint
tjruwase Jun 6, 2022
be86df9
Remove debug prints
tjruwase Jun 6, 2022
c87543b
Remove debug prints; Rebase unit test
tjruwase Jun 6, 2022
c18ff2d
Add reshape assert
tjruwase Jun 6, 2022
4ea36b7
Padding
tjruwase Jun 6, 2022
0371581
Typo
tjruwase Jun 6, 2022
a74abc1
Catch nonexistent checkpoint path
tjruwase Jun 7, 2022
2b707f2
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 7, 2022
529dbae
Cleanup
tjruwase Jun 7, 2022
e126d2e
Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…
tjruwase Jun 7, 2022
9e2766f
Restore checkpoint state comparisons
tjruwase Jun 10, 2022
5c90ef1
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 13, 2022
726982b
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 14, 2022
add1d0c
Merge branch 'master' into olruwase/elastic-ckpt-refresh
jeffra Jun 16, 2022
5fca3db
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 16, 2022
901b1e6
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 20, 2022
30896de
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 20, 2022
93934f6
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 21, 2022
ecb3dc8
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jun 22, 2022
6c7d947
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 4, 2022
cd8dea7
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 7, 2022
4217be2
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 11, 2022
206e630
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 14, 2022
14980ad
Add torch version guards
tjruwase Jul 18, 2022
f314581
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 18, 2022
868c463
More precise avoidance of false positives.
tjruwase Jul 19, 2022
e22487a
Merge branch 'olruwase/elastic-ckpt-refresh' of github.com:microsoft/…
tjruwase Jul 19, 2022
e0da15f
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 19, 2022
623430e
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 19, 2022
2556578
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 19, 2022
bf57d81
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 20, 2022
e4a5a46
Merge branch 'master' into olruwase/elastic-ckpt-refresh
tjruwase Jul 20, 2022
13 changes: 13 additions & 0 deletions deepspeed/checkpoint/__init__.py
@@ -0,0 +1,13 @@
from .reshape_meg_2d import reshape_meg_2d_parallel

from .deepspeed_checkpoint import DeepSpeedCheckpoint

from .utils import (get_layer_ckpt_name_for_rank,
                    get_model_ckpt_name_for_rank,
                    get_zero_ckpt_name_for_rank)

from .reshape_utils import (merge_state)

from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

from .zero_checkpoint import ZeROCheckpoint
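
For orientation, a minimal sketch (not part of the diff) of what the new package-level re-exports enable: downstream tools can import the checkpoint machinery from `deepspeed.checkpoint` directly instead of reaching into submodules.

```python
# Minimal sketch: these names are re-exported by the __init__.py added above,
# so they can be imported from the package root.
from deepspeed.checkpoint import (
    DeepSpeedCheckpoint,
    ZeROCheckpoint,
    reshape_meg_2d_parallel,
    merge_state,
)
```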
19 changes: 17 additions & 2 deletions deepspeed/checkpoint/constants.py
@@ -11,15 +11,30 @@

BASE_OPTIMIZER_STATE = 'base_optimizer_state'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
GROUPS_PADDING = 'groups_padding'

GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
FP32_WEIGHT_KEY = "fp32"

#########################################
# Module checkpoint keys
#########################################
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'

#########################################
# Checkpoint naming constants
#########################################
MODEL_FILE_PREFIX = 'mp_rank_'
ZERO_FILE_PREFIX = 'bf16_' + 'zero_pp_rank_'
OPTIM_FILE_SUFFIX = '_optim_states.pt'
MODEL_FILE_SUFFIX = '_model_states.pt'
LAYER_FILE_PREFIX = 'layer_'
BF16_ZERO_FILE_PREFIX = ZERO_FILE_PREFIX

#########################################
# Checkpoint utility keys
#########################################
DS_VERSION = 'ds_version'
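
A rough sketch of how the naming constants above compose into on-disk checkpoint file names; the zero-padded rank formatting and the `_mp_rank_XX` middle piece are illustrative assumptions, not taken from this diff.

```python
# Hedged sketch: assumed filename composition from the prefix/suffix constants.
from deepspeed.checkpoint.constants import (MODEL_FILE_PREFIX,
                                            MODEL_FILE_SUFFIX,
                                            ZERO_FILE_PREFIX,
                                            OPTIM_FILE_SUFFIX)

tp_rank, dp_rank = 0, 3
model_file = f'{MODEL_FILE_PREFIX}{tp_rank:02d}{MODEL_FILE_SUFFIX}'
# e.g. 'mp_rank_00_model_states.pt'
zero_file = f'{ZERO_FILE_PREFIX}{dp_rank}_{MODEL_FILE_PREFIX}{tp_rank:02d}{OPTIM_FILE_SUFFIX}'
# e.g. 'bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt'
print(model_file, zero_file)
```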
316 changes: 316 additions & 0 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -0,0 +1,316 @@
import os
from typing import Dict
import torch

from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation,
                            merge_state,
                            partition_data,
                            get_files,
                            get_files_with_prefix)

from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)

from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
from .zero_checkpoint import ZeROCheckpoint
from .constants import *

EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'

SEQUENTIAL_LAYERS = [
    'input_layernorm.weight',
    'input_layernorm.bias',
    'self_attention.dense.bias',
    'post_attention_layernorm.weight',
    'post_attention_layernorm.bias',
    'mlp.dense_4h_to_h.bias',
    'position_embeddings.weight'
]

LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}


class DeepSpeedCheckpoint(object):
    def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
        self.dir = dir
        self._validate_folder(dir)

        self.file_list = get_files(dir)
        self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)

        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)
        self.original_tp_degree = len(
            get_files_with_prefix(self.layer_files,
                                  f'{LAYER_FILE_PREFIX}01'))
        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        self.original_dp_degree = max(
            1,
            len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree))

        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree

        self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree
        self.world_size = self.tp_degree * self.pp_degree * self.dp_degree

        self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree,
                                              self.original_tp_degree)
        self.old_2d_map.simple_init()
        self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree,
                                                  old_tp_degree=self.original_tp_degree,
                                                  new_pp_degree=self.pp_degree,
                                                  new_tp_degree=self.tp_degree)

        self.zero_checkpoint = ZeROCheckpoint(dir)
        if self.is_change_pp_degree() or self.is_change_tp_degree(
        ) or self.is_change_dp_degree():
            self.zero_checkpoint.reshape(
                model_3d_desc(self.pp_degree,
                              self.tp_degree,
                              self.dp_degree))

        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(
            FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def is_change_tp_degree(self):
        return self.tp_degree != self.original_tp_degree

    def is_change_pp_degree(self):
        return self.pp_degree != self.original_pp_degree

    def is_change_dp_degree(self):
        return self.dp_degree != self.original_dp_degree

    def show_2d_mapping(self):
        print(f'reshaped 2d map ---- begin')

        for i in range(self.pp_degree):
            for j in range(self.tp_degree):
                file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j)
                print(f'[{i}, {j}] = {file_list}')

        print(f'reshaped 2d map ---- end')

    def show_tp_embedding_map(self):
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_tranformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files')

    def _build_global_state(self):
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict:
        return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
                                                       tp_index=tp_index,
                                                       dp_index=dp_index,
                                                       keys_to_ignore=[PARAM_SHAPES])

    def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
        return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index,
                                                       tp_index=tp_index,
                                                       dp_index=dp_index)

    def get_embedding_layer_id(self):
        return self.layer_keys[EMBEDDING_LAYER_INDEX]

    def get_final_norm_layer_id(self):
        return self.layer_keys[FINAL_LAYER_NORM_INDEX]

    def get_iteration(self):
        if not ITERATION_KEY in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map.keys()
        sd_list = [
            torch.load(fname,
                       map_location=torch.device('cpu'))
            for fname in self.tp_to_embedding_map[tp_index]
        ]
        sd = self._merge_state_dicts(sd_list)
        return sd

    def get_embedding_files(self, tp_index: int) -> list:
        assert tp_index in self.tp_to_embedding_map.keys()
        return self.tp_to_embedding_map[tp_index]

    def _get_checkpoint_value(self, key):
        if not key in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[key] = sd.get(key, None)

        return self.global_state[key]

    def get_args(self):
        return self._get_checkpoint_value(ARGS_KEY)

    def get_checkpoint_info(self):
        return self._get_checkpoint_value(CHECKPOINT_INFO_KEY)

    def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
        sd_list = [
            torch.load(fname,
                       map_location=torch.device('cpu')) for fname in fname_list
        ]

        merged_sd = None
        for sd in sd_list:
            if merged_sd is None:
                merged_sd = sd
            else:
                merged_sd = merge_state(merged_sd, sd)

        return merged_sd

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [
                torch.load(fname,
                           map_location=torch.device('cpu')) for fname in fname_list
            ]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)
        return t_list

    def get_pp_transformer_map(self, pp_index: int) -> list:
        assert pp_index < self.pp_degree
        return self.pp_to_transformer_map[pp_index]

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map.keys()
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0],
                        map_location=torch.device('cpu'))
        return sd

    def get_final_norm_files(self, tp_index: int) -> list:
        assert tp_index in self.tp_to_final_norm_map.keys()
        return self.tp_to_final_norm_map[tp_index]

    def _build_tp_other_layer_map(self, layer_index: int):
        assert layer_index < len(self.layer_files)
        layer_files = get_files_with_prefix(self.layer_files,
                                            self.layer_keys[layer_index])
        layer_file_partitions = partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map

    def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        file_indices = self.new_2d_map.get_data(pp_index=pp_index, tp_index=tp_index)
        return [self.mp_rank_files[i] for i in file_indices]

    def _build_pp_transformer_map(self):
        data_map = {}
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        data_map = {
            i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
            for i in range(0,
                           self.pp_degree)
        }
        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        # XXX: this is not guaranteed
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        if layers_per_pp == 0:
            layers_per_pp = 1
        #print(f"{transformer_layer_keys} {layers_per_pp}")
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                map_key = (tp_index, pp_index)
                if not map_key in file_map.keys():
                    file_map[map_key] = []
                file_map[map_key].append(layer_file_partitions[tp_index])

        return file_map

    def _sanity_check(self):
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        assert len(self.layer_keys) > 2
        # XXX: fix me - isn't always the case
        # only true with --pp-partition-method 'type:transformer|embedding' \
        # assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def validate_files(self):
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} is not existent')

    def _get_layer_keys(self):
        key_set = set()
        key_len = len(LAYER_FILE_PREFIX) + 2
        for file_path in self.layer_files:
            _, fname = os.path.split(file_path)
            key_set.add(fname[:key_len])
        return sorted(list(key_set))

    def _merge_state_dicts(self, sd_list):
        merged_sd = {}
        for key in sd_list[0].keys():
            if not key in SEQUENTIAL_LAYERS:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                merged_sd[key] = sd_list[0][key]

        return merged_sd

    def _validate_folder(self, dir):
        basic_folder_validation(dir)

        file_list = get_files(dir)

        for file_prefix in [
                MODEL_FILE_PREFIX,
                LAYER_FILE_PREFIX,
                f'{LAYER_FILE_PREFIX}01'
        ]:
            ckpt_files = get_files_with_prefix(file_list, file_prefix)
            assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'
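
A hedged end-to-end usage sketch of the `DeepSpeedCheckpoint` class added above; the checkpoint path and target parallelism degrees are hypothetical.

```python
# Hedged sketch: load an existing 3D-parallel checkpoint folder and view it
# under new TP/PP degrees. The path and degrees are made-up examples.
from deepspeed.checkpoint import DeepSpeedCheckpoint

ds_ckpt = DeepSpeedCheckpoint('/path/to/checkpoints/global_step1000',
                              tp_degree=2,
                              pp_degree=1)
print('iteration:', ds_ckpt.get_iteration())
ds_ckpt.show_transformer_file_map()

# Layer state dicts merged across the original TP partitions for the
# requested target indices.
embedding_sd = ds_ckpt.get_embedding_state(tp_index=0)
transformer_sds = ds_ckpt.get_transformer_state(tp_index=0, pp_index=0)
```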