diff --git a/megatron/arguments.py b/megatron/arguments.py index 230bd4d65..416e28aa8 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -637,6 +637,8 @@ def _add_learning_rate_args(parser): '(learning rate, warmup iterations, minimum learning ' 'rate, maximum number of iterations, and decay style ' 'from checkpoint and ignore input arguments.') + group.add_argument('--universal-checkpoint', action='store_true', + help='Loading a universal format checkpoint.') return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index d9a30f468..dacbec7dc 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -27,7 +27,8 @@ mpu, print_rank_0, update_num_microbatches, - utils) + utils, + get_tokenizer) from megatron.enums import PositionEmbeddingType _CHECKPOINT_VERSION = None @@ -131,6 +132,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): state_dict['checkpoint_version'] = 3.0 state_dict['iteration'] = iteration state_dict['tokens'] = args.consumed_train_tokens + state_dict['checkpoint_info'] = _checkpoint_info() # DeepSpeed saves the model/optimizer/scheduler if not args.deepspeed: @@ -361,7 +363,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True assert args.consumed_valid_samples == 0 if 'args' in state_dict: checkpoint_args = state_dict['args'] - check_checkpoint_args(checkpoint_args) + if not args.universal_checkpoint: + check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) update_num_microbatches(consumed_samples=args.consumed_train_samples) @@ -468,3 +471,13 @@ def load_biencoder_checkpoint(model, only_query_model=False, print(' successfully loaded {}'.format(checkpoint_name)) return model + + +def _checkpoint_info(): + args = get_args() + tokenizer = get_tokenizer() + + return { + "padded_vocab_size": args.padded_vocab_size, + "original_vocab_size": tokenizer.vocab_size, + } \ No newline at end of file diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py index c8109b3d2..b74c96f47 100644 --- a/megatron/data/data_samplers.py +++ b/megatron/data/data_samplers.py @@ -55,6 +55,7 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None): return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, + generator=torch.Generator().manual_seed(args.seed), pin_memory=True) class MegatronPretrainingSampler: diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 563566b70..95b2278ce 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -26,6 +26,7 @@ import torch.nn.functional as F from torch.nn import init import importlib +from megatron import mpu from megatron import get_args @@ -97,8 +98,26 @@ def reset_parameters(self): def forward(self, input): + + torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) + torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) + if self.use_meg_ds_fused_layer_norm: return FusedLayerNormAffineFunction.apply( input, self.weight, self.bias, self.normalized_shape, self.eps) else: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) + + + def forward1(self, input): + # TODO: temporary hack in order to synchronize all layer norms params despite them being + # unsynced at the moment due to a 
bug in deepspeed's bf16 optimizer + if 1: + tp_world_size = mpu.get_tensor_model_parallel_world_size() + weight = torch.clone(self.weight) + bias = torch.clone(self.bias) + weight = mpu.reduce_from_tensor_model_parallel_region(weight) / tp_world_size + bias = mpu.reduce_from_tensor_model_parallel_region(bias) / tp_world_size + + return FusedLayerNormAffineFunction.apply( + input, weight, bias, self.normalized_shape,self.eps) diff --git a/megatron/training.py b/megatron/training.py index bbf6623e3..5152c820a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -367,6 +367,32 @@ def get_learning_rate_scheduler(optimizer): return lr_scheduler +def sync_hp_to_lp(optimizer): + + optimizer.update_lp_params() + + # for n,p in model.named_parameters(): + # print(n) + + # if p._hp_mapping is not None: + # #print(f'rank {rank} fixing hp for input_layernorm') + # #p._hp_mapping.update_hp() + + # hp = p._hp_mapping.hp_fragment + + + + # torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) + + # # 3. optim states + # for key in ['exp_avg', 'exp_avg_sq']: + # optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key) + # #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}') + # torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) + # #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}') + + + def setup_model_and_optimizer(model_provider_func): """Setup model and optimizer.""" args = get_args() @@ -386,12 +412,21 @@ def setup_model_and_optimizer(model_provider_func): if args.deepspeed: print_rank_0("DeepSpeed is enabled.") - pp = mpu.get_pipeline_model_parallel_world_size() + #pp = mpu.get_pipeline_model_parallel_world_size() + + import json + import io + with io.open(args.deepspeed_config, "r", encoding="utf-8") as f: + config = json.load(f) + if args.universal_checkpoint: + config["checkpoint"] = {"load_universal": True} + model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model[0], optimizer=optimizer, + lr_scheduler=lr_scheduler, + config=config, args=args, - lr_scheduler=lr_scheduler ) assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed" @@ -416,8 +451,37 @@ def setup_model_and_optimizer(model_provider_func): torch.distributed.barrier() timers('load-checkpoint').stop() timers.log(['load-checkpoint']) + + + # hp -> lp + if args.deepspeed and args.universal_checkpoint: + sync_hp_to_lp(optimizer) + + else: args.iteration = 0 + + from .utils import dump_weights + dump_weights(f'{args.universal_checkpoint=}', args.iteration, model, optimizer) + + # tp_rank = mpu.get_tensor_model_parallel_rank() + # pp_rank = mpu.get_pipeline_model_parallel_rank() + # dp_rank = mpu.get_data_parallel_rank() + # for n,p in model[0].named_parameters(): + # if 'word_embeddings.weight' not in n: + # continue + # if tp_rank == 0 and pp_rank == 0: + # print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}") + # if p._hp_mapping is not None: + # hp = p._hp_mapping.hp_fragment + # print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}') + + # if tp_rank == 0 and pp_rank == mpu.get_pipeline_model_parallel_world_size() - 1: + # print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}") + # if p._hp_mapping is not None: + # hp = p._hp_mapping.hp_fragment + # print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}') + # We only support local DDP 
with multiple micro-batches. if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1: diff --git a/megatron/utils.py b/megatron/utils.py index 98d2f611c..fe0b09de5 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -392,3 +392,79 @@ def found_kill_switch(): return True else: return False + +def get_fingerprint_header(): + return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata" + +def get_fingerprint(p): + return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}" + + +def dump_weights(preamble, iteration, model, optimizer, tensor=None): + tp_rank = mpu.get_tensor_model_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + dp_rank = mpu.get_data_parallel_rank() + dp_size = mpu.get_data_parallel_world_size() + fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt" + + # only care for first and last pp stages and dp0 tp0 + #if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()): + # return + + #if not (tp_rank == 0 and dp_rank == 0): + # return + + if tensor is not None: + orig_tensor = tensor + if hasattr(tensor, "_hp_param"): + numel = tensor._hp_param.numel() # // dp_size + tensor = tensor.flatten().narrow(0, 0, numel) + + #print(fn) + with open(fn, "w") as fh: + fh.write(f"{get_fingerprint_header()}\n") + + if tensor is not None: + fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n") + else: + for n, p in model[0].named_parameters(): + fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n") + + + return + + + # until we figure out how to dump the actual fp32 values don't do this + fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt" + with open(fn, "w") as fh: + fh.write(f"{get_fingerprint_header()}\n") + if tensor is not None: + tensor = orig_tensor + if hasattr(tensor, "_hp_param"): + fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n") + #fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n") + else: + fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n") + #fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n") + + else: + if hasattr(model[0].module.tied_modules, "embed"): + p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param + fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n") + + # for i, param_group in enumerate(optimizer.param_groups): + # fh.write(f"{get_fingerprint(optimizer.fp32_groups_flat_partition[i])} group={i}\n") + #fh.write(f"{i}={optimizer.fp32_groups_flat_partition[i]}\n") + # if mpu.is_pipeline_first_stage(): + # x = optimizer.fp32_groups_flat_partition[0] + # fh.write(f"fp32={x[:402432]}\n") + # if mpu.is_pipeline_last_stage()): + # x = optimizer.fp32_groups_flat_partition[1] + # fh.write(f"fp32={x[-402432:]}\n") + + # import os + # import socket + # hostname = socket.gethostname() + # pid = os.getpid() + # global_rank = torch.distributed.get_rank() + #fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt" \ No newline at end of file diff --git a/run_bf16.sh b/run_bf16.sh index fd3a48398..fc884d4af 100755 --- a/run_bf16.sh +++ b/run_bf16.sh @@ -12,7 +12,12 @@ DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` #DATASET_3="" #DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}" -BASE_DATA_PATH=/data/Megatron-LM/data +#BASE_DATA_PATH=tests/data/gpt2 +#DATASET=${BASE_DATA_PATH}/meg-gpt2-openwebtext_text_document 
+#VOCAB_PATH=${BASE_DATA_PATH}/gpt2-tiny-vocab.json +#MERGE_PATH=${BASE_DATA_PATH}/gpt2-tiny-merges.txt + +BASE_DATA_PATH=/vc_data/Megatron-LM/data DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt @@ -20,40 +25,45 @@ MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt script_path=$(realpath $0) script_dir=$(dirname $script_path) -#CONFIG_JSON="$script_dir/ds_config.json" -CONFIG_JSON="/tmp/ds_config.json" +CONFIG_JSON="$script_dir/ds_config.json" +#CONFIG_JSON="/tmp/ds_config.json" USE_DEEPSPEED=1 ZERO_STAGE=0 - -# Debug #TP=4 #PP=4 -#LAYERS=8 -#HIDDEN=512 -#SEQ=1024 -#GLOBAL_BATCH=128 -#WORKER_STR="-i worker-0" - -TP=1 -PP=1 -DP=2 +# Debug +DEBUG_MODE=0 +if [[ $DEBUG_MODE == 1 ]]; then + LAYERS=4 + HIDDEN=512 + SEQ=512 + EXIT_INTERVAL=3 +else + HIDDEN=1024 + LAYERS=24 + SEQ=1024 + EXIT_INTERVAL=10 +fi + +TP=2 +PP=2 +DP=4 WORLD_SIZE=$((TP*PP*DP)) -HIDDEN=1024 -LAYERS=24 -SEQ=1024 -GLOBAL_BATCH=1 -WORKER_STR="" +GLOBAL_BATCH=4 MICRO_BATCH=1 +TRAIN_ITERS=100000 +CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP} +LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP} LR=6.0e-4 MIN_LR=6.0e-5 DTYPE="bf16" -EXP_DIR=${HOME}/experiments/results/bf16 -LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3" +EXP_DIR=${HOME}/experiments/results/ckpt_reshape +LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont" mkdir -p $LOG_DIR while [[ $# -gt 0 ]] @@ -89,7 +99,7 @@ options=" \ --max-position-embeddings $SEQ \ --micro-batch-size $MICRO_BATCH \ --global-batch-size $GLOBAL_BATCH \ - --train-iters 1000 \ + --train-iters $TRAIN_ITERS \ --lr $LR \ --min-lr $MIN_LR \ --lr-decay-style cosine \ @@ -99,7 +109,7 @@ options=" \ --data-path ${DATASET} \ --vocab-file ${VOCAB_PATH} \ --merge-file ${MERGE_PATH} \ - --save-interval 10000 \ + --save-interval 1000 \ --split 98,2,0 \ --clip-grad 1.0 \ --weight-decay 0.1 \ @@ -108,7 +118,12 @@ options=" \ --init-method-std 0.006 \ --${DTYPE} \ --checkpoint-activations \ - --exit-interval 10000 \ + --exit-interval ${EXIT_INTERVAL} \ + --save ${CHECKPOINT_PATH} \ + --load ${LOAD_CHECKPOINT_PATH} \ + --position-embedding-type alibi \ + --override-lr-scheduler \ + --embed-layernorm \ --tensorboard-dir $LOG_DIR " @@ -151,7 +166,7 @@ cat < $CONFIG_JSON } EOT -WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE" +#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE" #WORKER_STR="-i worker-0:0,1,2,3" #run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}" #run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}" diff --git a/run_universal_bf16.sh b/run_universal_bf16.sh new file mode 100755 index 000000000..7a60c34c1 --- /dev/null +++ b/run_universal_bf16.sh @@ -0,0 +1,180 @@ +#!/bin/bash + + +DIR=`pwd` +DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` +#mkdir -p $DIR/logs +#mkdir -p /tmp/logs + + +#DATASET_1="" +#DATASET_2="" +#DATASET_3="" +#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}" + +#BASE_DATA_PATH=tests/data/gpt2 +#DATASET=${BASE_DATA_PATH}/meg-gpt2-openwebtext_text_document +#VOCAB_PATH=${BASE_DATA_PATH}/gpt2-tiny-vocab.json +#MERGE_PATH=${BASE_DATA_PATH}/gpt2-tiny-merges.txt + +BASE_DATA_PATH=/vc_data/Megatron-LM/data +DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron 
+VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json +MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt + + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +CONFIG_JSON="$script_dir/ds_config.json" +#CONFIG_JSON="/tmp/ds_config.json" + +USE_DEEPSPEED=1 +ZERO_STAGE=0 + +#TP=4 +#PP=4 + +# Debug +DEBUG_MODE=0 +if [[ $DEBUG_MODE == 1 ]]; then + LAYERS=4 + HIDDEN=512 + SEQ=512 + EXIT_INTERVAL=3 +else + HIDDEN=1024 + LAYERS=24 + SEQ=1024 + EXIT_INTERVAL=10 +fi + +TP=2 +PP=2 +DP=4 +WORLD_SIZE=$((TP*PP*DP)) +GLOBAL_BATCH=4 + +MICRO_BATCH=1 +TRAIN_ITERS=100000 +CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP} +LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp2_pp2_dp4 + +LR=6.0e-4 +MIN_LR=6.0e-5 +DTYPE="bf16" +EXP_DIR=${HOME}/experiments/results/ckpt_reshape +LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_uni" +mkdir -p $LOG_DIR + +while [[ $# -gt 0 ]] +do +key="$1" +case $key in + --no-deepspeed) + USE_DEEPSPEED=0; + shift + ;; + -z|--zero-stage) + ZERO_STAGE=$2; + shift + ;; + *) + echo "Unknown argument(s)" + usage + exit 1 + shift + ;; +esac +done + + +options=" \ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $LAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads 32 \ + --seq-length $SEQ \ + --loss-scale 12 \ + --max-position-embeddings $SEQ \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --train-iters $TRAIN_ITERS \ + --lr $LR \ + --min-lr $MIN_LR \ + --lr-decay-style cosine \ + --log-interval 1 \ + --eval-iters 40 \ + --eval-interval 10 \ + --data-path ${DATASET} \ + --vocab-file ${VOCAB_PATH} \ + --merge-file ${MERGE_PATH} \ + --save-interval 1000 \ + --split 98,2,0 \ + --clip-grad 1.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.006 \ + --${DTYPE} \ + --checkpoint-activations \ + --exit-interval ${EXIT_INTERVAL} \ + --save ${CHECKPOINT_PATH} \ + --load ${LOAD_CHECKPOINT_PATH} \ + --universal-checkpoint \ + --position-embedding-type alibi \ + --override-lr-scheduler \ + --embed-layernorm \ + --tensorboard-dir $LOG_DIR + " + + +if [[ ${USE_DEEPSPEED} -eq 1 ]]; then + echo "Using DeepSpeed" + options="${options} \ + --deepspeed \ + --deepspeed_config=${CONFIG_JSON} \ + --zero-stage=${ZERO_STAGE} \ + --deepspeed-activation-checkpointing \ + " +fi + + +cat < $CONFIG_JSON +{ + "train_batch_size" : $GLOBAL_BATCH, + "train_micro_batch_size_per_gpu": $MICRO_BATCH, + "steps_per_print": 1, + + "zero_optimization": { + "stage": $ZERO_STAGE + }, + + "bf16": { + "enabled": true + }, + + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + + "wall_clock_breakdown" : true +} +EOT + +#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE" +#WORKER_STR="-i worker-0:0,1,2,3" +#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}" +#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}" +run_cmd="deepspeed --master_port 29700 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}" + + +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/tests/ds_config_bf16.json b/tests/ds_config_bf16.json new file mode 100644 index 000000000..6afd1f6b2 --- /dev/null +++ b/tests/ds_config_bf16.json @@ -0,0 +1,14 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": 16, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + 
"bf16": { + "enabled": true + }, + "zero_allow_untested_optimizer": true, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} diff --git a/tests/test_checkpoints.py b/tests/test_checkpoints.py new file mode 100644 index 000000000..fdc41e014 --- /dev/null +++ b/tests/test_checkpoints.py @@ -0,0 +1,298 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import pytest +from pathlib import Path + +from parameterized import parameterized +from megatron.testing_utils import ( + CaptureStdout, + TestCasePlus, + execute_subprocess_async, + get_gpu_count, + require_deepspeed, + require_torch_gpu, + require_torch_multi_gpu, + set_seed +) + +set_seed(42) + + +def parameterized_custom_name_func(func, param_num, param): + # customize the test name generator function as we want both params to appear in the sub-test + # name, as by default it shows only the first param + param_based_name = parameterized.to_safe_name("_to_".join(str(x) for x in param.args)) + return f"{func.__name__}_{param_based_name}" + +params = [ + # TP_PP_DP + ["1_1_1", "1_1_1"], + ["2_1_1", "1_1_1"], + ["1_2_1", "1_1_1"], + ["1_1_2", "1_1_1"], + + ["2_1_1", "2_1_1"], + ["1_1_1", "2_1_1"], + ["1_1_1", "1_2_1"], + ["1_1_1", "1_1_2"], + + ["1_1_2", "1_1_2"], + ["1_1_2", "2_1_1"], + ["1_1_2", "1_2_1"], + + ["1_2_1", "1_2_1"], + ["1_2_1", "2_1_1"], + ["1_2_1", "1_1_2"], + + ["2_1_1", "2_1_1"], + ["2_1_1", "1_2_1"], + ["2_1_1", "1_1_2"], + + ["2_2_2", "1_1_1"], + ["2_2_2", "2_2_2"], + ["1_1_1", "2_2_2"], + + ["1_1_8", "2_2_2"], + +] + +def get_launcher(num_gpus): + # 1. 
explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup + # - it won't be able to handle that + return f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split() + +@require_deepspeed +@require_torch_gpu +class MegDSTestCheckpoints(TestCasePlus): + """ """ + + def setUp(self): + super().setUp() + + # at times magatron fails to build kernels and doesn't remove the lock file, which makes + # subsequent runs hang - so make sure there is no lock when starting the testing + meg_lock_file_path = self.repo_root_dir_str + "/megatron/fused_kernels/build/lock" + if os.path.exists(meg_lock_file_path): + os.unlink(meg_lock_file_path) + + def get_config(self, output_dir, tp_size, pp_size, dp_size): + data_dir = f"{self.data_dir}/gpt2" + + num_gpus = pp_size * tp_size * dp_size + print(f"Using {num_gpus} GPUs") + + n_samples = 300 # about 56 iterations + + exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume + seq_len = 128 + + # XXX: for now while testing shapes make it really short and fast + exit_interval = 1 + seq_len = 8 + + + # common/shared configs + + ds_args = f""" + --deepspeed + --deepspeed_config {self.test_file_dir_str}/ds_config_bf16.json + --zero-stage 0 + --deepspeed-activation-checkpointing + """.split() + + args = f""" + --tensor-model-parallel-size {tp_size} + --pipeline-model-parallel-size {pp_size} + --distributed-backend nccl + + --log-interval 1 + --save-interval 1 + --eval-interval 10 + --eval-iters 1 + --checkpoint-activations + --partition-activations + --exit-interval {exit_interval} + + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab-file {data_dir}/gpt2-tiny-vocab.json + --save {output_dir}/checkpoints + --load {output_dir}/checkpoints + --data-path {data_dir}/meg-gpt2-openwebtext_text_document + --tensorboard-dir {output_dir}/tensorboard + --tensorboard-queue-size 5 + --log-timers-to-tensorboard + --log-batch-size-to-tensorboard + --log-validation-ppl-to-tensorboard + + --num-layers 2 + --hidden-size 8 + --num-attention-heads 2 + --seq-length {seq_len} + --max-position-embeddings 8 + --micro-batch-size 1 + --global-batch-size 16 + --train-samples {n_samples} + + --embed-layernorm + --position-embedding-type alibi + + --optimizer adam + --adam-beta1 0.9 + --adam-beta2 0.95 + --adam-eps 1e-8 + --lr 1e-4 + --lr-warmup-samples 5 + --lr-decay-samples 6 + --clip-grad 1.0 + --weight-decay 1e-1 + --bf16 + + --log-level debug + --log-level-replica info + """.split() + + + # XXX: fails to handle: + #--embed-layernorm + # +# stderr: RuntimeError: Error(s) in loading state_dict for VocabParallelEmbedding: +# stderr: size mismatch for norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]). +# stderr: size mismatch for norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]). + + return args, ds_args, num_gpus + + + def train_checkpoint(self, output_dir, tp_size=1, pp_size=1, dp_size=1): + src_dir = self.src_dir + script = [f"{src_dir}/pretrain_gpt.py"] + + args, ds_args, num_gpus = self.get_config(output_dir, tp_size, pp_size, dp_size) + launcher = get_launcher(num_gpus) + cmd = launcher + script + args + ds_args + # keep for quick debug + #print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + + # 1. 
test training from scratch (no checkpoint) + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # test deepspeed is running + self.assertIn("DeepSpeed info", cs.out) + + # test reports + self.assertIn("consumed samples", cs.out) + + # test there should be no checkpoint this round + self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out) + + # test checkpoint saving + self.assertIn("successfully saved checkpoint at iteration", cs.out) + + def convert_checkpoint_to_universal(self, output_dir, step): + cmd = f""" + python tools/convert_checkpoint/ds_to_universal.py + --input_folder {output_dir}/checkpoints/global_step{step} + --output_folder {output_dir}/checkpoints/global_step{step}_universal + """.split() + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + self.assertIn("Convert DeepSpeed Checkpoint to Universal Checkpoint", cs.out) + + def resume_from_checkpoint(self, output_dir, tp_size=1, pp_size=1, dp_size=1): + src_dir = self.src_dir + script = [f"{src_dir}/pretrain_gpt.py"] + + args, ds_args, num_gpus = self.get_config(output_dir, tp_size, pp_size, dp_size) + launcher = get_launcher(num_gpus) + cmd = launcher + script + args + ds_args + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # test checkpoint loading + self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out) + + # test reports + self.assertIn("consumed samples", cs.out) + + # test checkpoint saving + self.assertIn("successfully saved checkpoint at iteration", cs.out) + + def resume_from_universal_checkpoint(self, output_dir, tp_size=1, pp_size=1, dp_size=1): + src_dir = self.src_dir + script = [f"{src_dir}/pretrain_gpt.py"] + + args, ds_args, num_gpus = self.get_config(output_dir, tp_size, pp_size, dp_size) + launcher = get_launcher(num_gpus) + cmd = launcher + script + args + ds_args + ["--universal-checkpoint"] + # keep for quick debug + #print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # test checkpoint loading + self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out) + + # test reports + self.assertIn("consumed samples", cs.out) + + # test checkpoint saving + self.assertIn("successfully saved checkpoint at iteration", cs.out) + + + @require_torch_multi_gpu + @parameterized.expand(params, name_func=parameterized_custom_name_func) + def test_checkpoint_reshaping_main(self, src, tgt): + # this test needs at least 2 gpus - if there are more gpus it will do more extensive testing + + tp_size_src, pp_size_src, dp_size_src = list(map(int, src.split('_'))) + tp_size_tgt, pp_size_tgt, dp_size_tgt = list(map(int, tgt.split('_'))) + + n_gpus = get_gpu_count() + n_gpus_src = tp_size_src * pp_size_src * dp_size_src + n_gpus_tgt = tp_size_tgt * pp_size_tgt * dp_size_tgt + + if n_gpus_src > n_gpus: + pytest.skip(f"the test requires {n_gpus_src} gpus for source topology but have only {n_gpus}") + if n_gpus_tgt > n_gpus: + pytest.skip(f"the test requires {n_gpus_tgt} gpus for target topology but have only {n_gpus}") + + output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False) + + # 1. 
train with initial topology defined in the first arg of params + self.train_checkpoint(output_dir, tp_size=tp_size_src , pp_size=pp_size_src , dp_size=dp_size_src ) + + # 2. convert checkpoint to universal checkpoint (topology ) + self.convert_checkpoint_to_universal(output_dir=output_dir, step=1) + + # 3. check we can resume training from a reshaped checkpoint to the target topology - the last arg of params + self.resume_from_universal_checkpoint(output_dir, tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_tgt) + + + @require_torch_multi_gpu + def test_checkpoint_reshaping_empty_dir(self): + + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + with self.assertRaises(RuntimeError) as context: + self.convert_checkpoint_to_universal(output_dir=output_dir, step=1) diff --git a/tools/convert_checkpoint/deepspeed_checkpoint.py b/tools/convert_checkpoint/deepspeed_checkpoint.py deleted file mode 100644 index 52dff44f2..000000000 --- a/tools/convert_checkpoint/deepspeed_checkpoint.py +++ /dev/null @@ -1,195 +0,0 @@ -import os -from typing import Dict -import torch - -ZERO_FILE_PREFIX = 'zero_pp_rank_' -LAYER_FILE_PREFIX = 'layer_' -MP_RANK_FILE_PREFIX = 'mp_rank_' -EMBEDDING_LAYER_INDEX = 0 -FINAL_LAYER_NORM_INDEX = -1 -ARGS_KEY = 'args' -ITERATION_KEY = 'iteration' -SEQUENTIAL_LAYERS = [ - 'input_layernorm.weight', 'input_layernorm.bias', - 'self_attention.dense.bias', - 'post_attention_layernorm.weight', 'post_attention_layernorm.bias', - 'mlp.dense_4h_to_h.bias', - 'position_embeddings.weight' -] - -LAYER_CONCAT_DIM = { - 'self_attention.dense.weight': 1, - 'mlp.dense_4h_to_h.weight': 1 -} - -class DeepSpeedCheckpoint(object): - def __init__(self, dir, tp_degree=None, pp_degree=None): - self.dir = dir - self.file_list = self._get_files(dir) - self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) - self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX) - self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX) - self.layer_keys = self._get_layer_keys() - self.layer_count = len(self.layer_keys) - self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01')) - self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree - self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree) - self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree - self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree - self.global_state = {} - - self._sanity_check() - self.pp_to_transformer_map = self._build_pp_transformer_map() - self.transformer_file_map = self._build_transformer_file_map() - self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) - self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX) - self._build_global_state() - - - - def show_tp_embedding_map(self): - self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') - - def show_tp_final_norm_map(self): - self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers') - - def show_pp_tranformer_map(self): - self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers') - - def show_transformer_file_map(self): - self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files') - - def _build_global_state(self): - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) - self.global_state[ITERATION_KEY] = 
sd.get(ITERATION_KEY, 0) - self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) - - def get_iteration(self): - if not ITERATION_KEY in self.global_state: - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) - self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) - - return self.global_state[ITERATION_KEY] - - def get_embedding_state(self, tp_index: int) -> Dict: - assert tp_index in self.tp_to_embedding_map.keys() - sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]] - sd = self._merge_state_dicts(sd_list) - return sd - - def get_args(self): - if not ARGS_KEY in self.global_state: - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) - self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) - - return self.global_state[ARGS_KEY] - - - def get_transformer_state(self, tp_index: int, pp_index: int) -> list: - assert tp_index < self.tp_degree - assert pp_index < self.pp_degree - t_list = [] - for fname_list in self.transformer_file_map[(tp_index, pp_index)]: - sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] - sd = self._merge_state_dicts(sd_list) - t_list.append(sd) - return t_list - - def get_final_norm_state(self, tp_index:int) -> Dict: - assert tp_index in self.tp_to_final_norm_map.keys() - sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu')) - return sd - - def _build_tp_other_layer_map(self, layer_index:int): - assert layer_index < len(self.layer_files) - layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index]) - layer_file_partitions = self._partition_data(layer_files, self.tp_degree) - data_map = {i:flist for i, flist in enumerate(layer_file_partitions)} - return data_map - - def _build_pp_transformer_map(self): - data_map = {} - transformer_layers = self.layer_keys[1:-1] - layers_per_pp = len(transformer_layers) // self.pp_degree - data_map = {i:transformer_layers[i*layers_per_pp:(i+1)*layers_per_pp] for i in range(0, self.pp_degree)} - return data_map - - def _dump_mapping(self, data_map, map_tag = None): - if map_tag is not None: - print(f'Dump mapping: {map_tag}') - for k, v in data_map.items(): - print(f'{k} = {v}') - - def _build_transformer_file_map(self): - transformer_layer_keys = self.layer_keys[1:-1] - file_map = {} - layers_per_pp = len(transformer_layer_keys) // self.pp_degree - for key_index, layer_key in enumerate(transformer_layer_keys): - pp_index = key_index // layers_per_pp - layer_files = self._get_files_with_prefix(self.layer_files, layer_key) - layer_file_partitions = self._partition_data(layer_files, self.tp_degree) - for tp_index in range(self.tp_degree): - map_key = (tp_index, pp_index) - if not map_key in file_map.keys(): - file_map[map_key] = [] - file_map[map_key].append(layer_file_partitions[tp_index]) - - return file_map - - def _sanity_check(self): - assert len(self.mp_rank_files) % self.tp_degree == 0 - assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0 - assert len(self.layer_keys) > 2 - - # XXX: disable for now, since this fails when using: - # --pp-partition-method 'type:transformer|embedding' - # so if it can detect this flag somehow it then should validate: - # assert (len(self.layer_keys)) % self.pp_degree == 0 - # the original: - # assert (len(self.layer_keys) - 2) % self.pp_degree == 0 - - def _get_files_with_prefix(self, all_files, prefix): - file_list = [] - for file_path in all_files: - _, fname = 
os.path.split(file_path) - if fname.startswith(prefix): - file_list.append(file_path) - - return sorted(file_list) - - def validate_files(self): - for file in self.file_list: - if not os.path.isfile(file): - print(f'Error: {file} is not existent') - - def _get_files(self, dir): - file_list = [] - for root, dirs, files in os.walk(dir): - for file in files: - file_list.append(os.path.join(root, file)) - return file_list - - def _get_layer_keys(self): - key_set = set() - key_len = len(LAYER_FILE_PREFIX) + 2 - for file_path in self.layer_files: - _, fname = os.path.split(file_path) - key_set.add(fname[:key_len]) - return sorted(list(key_set)) - - def _partition_data(self, data_list, num_partitions): - num_elems = len(data_list) - assert num_elems % num_partitions == 0 - partition_size = num_elems // num_partitions - partitions_list = [data_list[i:i+partition_size] for i in range(0, num_elems, partition_size)] - return partitions_list - - def _merge_state_dicts(self, sd_list): - merged_sd = {} - for key in sd_list[0].keys(): - if not key in SEQUENTIAL_LAYERS: - cat_dim = LAYER_CONCAT_DIM.get(key, 0) - merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim) - else: - merged_sd[key] = sd_list[0][key] - return merged_sd diff --git a/tools/convert_checkpoint/deepspeed_to_deepspeed.py b/tools/convert_checkpoint/deepspeed_to_deepspeed.py new file mode 100644 index 000000000..8d484e88d --- /dev/null +++ b/tools/convert_checkpoint/deepspeed_to_deepspeed.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python +import sys +import argparse +import os +import torch + +from pathlib import Path + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + +from megatron.tokenizer.tokenizer import _vocab_size_with_padding +from deepspeed.checkpoint.deepspeed_checkpoint import ( + ARGS_KEY, + CHECKPOINT_INFO_KEY, +) + +from deepspeed.checkpoint import ( + DeepSpeedCheckpoint, + get_model_ckpt_name_for_rank, + get_zero_ckpt_name_for_rank, + get_layer_ckpt_name_for_rank +) + +CHECKPOINT_FILE_SUFFIX = '_model_states.pt' +MP_WORLD_SIZE ='mp_world_size' +WORD_EMBEDDINGS_KEY = 'word_embeddings.weight' +ORIGINAL_VOCAB_SIZE = 'original_vocab_size' +PADDED_VOCAB_SIZE = 'padded_vocab_size' + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_folder', + default=None, + type=str, + help='Input DeepSpeed Checkpoint folder') + parser.add_argument('--output_folder', + default=None, + type=str, + help='Output Megatron checkpoint folder') + parser.add_argument('--target_tp', + default=None, + type=int, + help='Target TP degree') + parser.add_argument('--target_pp', + default=None, + type=int, + help='Target PP degree') + parser.add_argument('--target_dp', + default=None, + type=int, + help='Target DP degree') + args = parser.parse_args() + print(f'args = {args}') + return args + + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + + +def _create_transformer_layer_checkpoint(ds_checkpoint, base_folder, tp_index, pp_index): + sd_list = ds_checkpoint.get_transformer_state(tp_index, pp_index) + layer_id_list = ds_checkpoint.get_pp_transformer_map(pp_index) + assert len(sd_list) == len(layer_id_list) + for sd, layer_id in zip(sd_list, layer_id_list): + ckpt_path = get_layer_ckpt_name_for_rank( + base_folder=base_folder, + layer_id=layer_id, + tp_rank=tp_index) + 
_save_checkpoint(ckpt_path, sd) + + +def _strip_vocab_padding(ds_checkpoint, padded_vocab_tensor): + target_args = ds_checkpoint.get_args() + checkpoint_info = ds_checkpoint.get_checkpoint_info() + target_args.tensor_model_parallel_size = ds_checkpoint.tp_degree + target_args.padded_vocab_size = _vocab_size_with_padding(checkpoint_info[ORIGINAL_VOCAB_SIZE], target_args) + assert target_args.padded_vocab_size <= padded_vocab_tensor.numel() + checkpoint_info[PADDED_VOCAB_SIZE] = target_args.padded_vocab_size + unpadded_vocab_tensor = torch.narrow(padded_vocab_tensor, 0, 0, target_args.padded_vocab_size) + return unpadded_vocab_tensor.clone() + + +def _create_embedding_layer_checkpoint(ds_checkpoint, base_folder, tp_index): + sd = ds_checkpoint.get_embedding_state(tp_index) + if ds_checkpoint.is_change_tp_degree(): + sd[WORD_EMBEDDINGS_KEY] = _strip_vocab_padding(ds_checkpoint, sd[WORD_EMBEDDINGS_KEY]) + layer_id = ds_checkpoint.get_embedding_layer_id() + ckpt_path = get_layer_ckpt_name_for_rank( + base_folder=base_folder, + tp_rank=tp_index, + layer_id=layer_id) + _save_checkpoint(ckpt_path, sd) + + +def _create_final_norm_layer_checkpoint(ds_checkpoint, base_folder, tp_index): + sd = ds_checkpoint.get_final_norm_state(tp_index) + layer_id = ds_checkpoint.get_final_norm_layer_id() + ckpt_path = get_layer_ckpt_name_for_rank( + base_folder=base_folder, + tp_rank=tp_index, + layer_id=layer_id) + _save_checkpoint(ckpt_path, sd) + + +def _create_2d_parallel_checkpoint(ds_checkpoint, base_folder, tp_index, + pp_index): + sd = ds_checkpoint.get_2d_parallel_state(tp_index=tp_index, + pp_index=pp_index) + sd[MP_WORLD_SIZE] = ds_checkpoint.tp_degree + file_id = pp_index * ds_checkpoint.tp_degree + tp_index + ckpt_path = get_model_ckpt_name_for_rank(base_folder, f'{file_id:02d}') + + # Adjust specific fields + sd[ARGS_KEY] = ds_checkpoint.get_args() + sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree + sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree + sd[CHECKPOINT_INFO_KEY][PADDED_VOCAB_SIZE] = sd[ARGS_KEY].padded_vocab_size + _save_checkpoint(ckpt_path, sd) + + +def _create_zero_checkpoint(ds_checkpoint, base_folder, dp_index, pp_index, tp_index): + _2d_rank = (pp_index * ds_checkpoint.tp_degree) + tp_index + sd = ds_checkpoint.get_zero_checkpoint_state( + pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index) + + ckpt_path = get_zero_ckpt_name_for_rank(base_folder=base_folder, + dp_rank=dp_index, + mp_rank=_2d_rank) + _save_checkpoint(ckpt_path, sd) + + +def _create_latest_file(base_folder, file_name, latest_tag): + file_path = os.path.join(base_folder, file_name) + os.makedirs(base_folder, exist_ok=True) + with open(file_path, 'w') as f: + f.write(str(latest_tag)) + + +def main(): + print(f'Convert DeepSpeed Checkpoint to DeepSpeed Checkpoint') + + args = parse_arguments() + print( + f'Converting DeepSpeed checkpoint in {args.input_folder} to DeepSpeed checkpoint in {args.output_folder}' + ) + + ds_checkpoint = DeepSpeedCheckpoint( + args.input_folder, + args.target_tp, + args.target_pp, + args.target_dp) + iteration = ds_checkpoint.get_iteration() + latest_tag = f'global_step{iteration}' + _create_latest_file(args.output_folder, + 'latest_checkpointed_iteration.txt', iteration) + _create_latest_file(args.output_folder, 'latest', latest_tag) + base_folder = os.path.join(args.output_folder, latest_tag) + + for i in range(ds_checkpoint.tp_degree): + _create_embedding_layer_checkpoint(ds_checkpoint, base_folder, i) + 
_create_final_norm_layer_checkpoint(ds_checkpoint, base_folder, i) + + for j in range(ds_checkpoint.pp_degree): + _create_transformer_layer_checkpoint(ds_checkpoint, base_folder, i, j) + _create_2d_parallel_checkpoint(ds_checkpoint, base_folder, i, j) + + for i in range(ds_checkpoint.dp_degree): + for j in range(ds_checkpoint.pp_degree): + for k in range(ds_checkpoint.tp_degree): + _create_zero_checkpoint(ds_checkpoint, base_folder, i, j, k) + + +if __name__ == "__main__": + main() diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 017036af4..74e5ca7c9 100755 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -13,18 +13,34 @@ ENCODER_KEY = 'encoder' WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head' WORD_EMBEDDINGS_KEY = 'word_embeddings' -FINAL_LAYER_NORM_KEY ='final_layernorm' +FINAL_LAYER_NORM_KEY = 'final_layernorm' CHECKPOINT_VERSION_KEY = 'checkpoint_version' CHECKPOINT_VERSION_VALUE = 3.0 ITERATION_KEY = 'iteration' + def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder') - parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder') - parser.add_argument('--target_tp', default=1, type=int, help='Target TP degree') - parser.add_argument('--target_pp', default=1, type=int, help='Target PP degree') - parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.') + parser.add_argument('--input_folder', + default=None, + type=str, + help='Input DeepSpeed Checkpoint folder') + parser.add_argument('--output_folder', + default=None, + type=str, + help='Output Megatron checkpoint folder') + parser.add_argument('--target_tp', + default=1, + type=int, + help='Target TP degree') + parser.add_argument('--target_pp', + default=1, + type=int, + help='Target PP degree') + parser.add_argument( + '--for_release', + action='store_true', + help='Convert for release purpose, reset some (progress) counters.') args = parser.parse_args() print(f'args = {args}') return args @@ -39,6 +55,7 @@ def _convert_ds_transformer_state(sd_list): return new_sd + def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): path_list = [] iter_folder = f'iter_{iteration:07d}' @@ -47,18 +64,18 @@ def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): for j in range(0, pp_degree): rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') - path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) + path_list[i].append( + os.path.join(base_folder, iter_folder, ckpt_path)) return path_list def _create_megatron_dict(): - language_model_dict = { - EMBEDDING_KEY: {}, - ENCODER_KEY: {} - } + language_model_dict = {EMBEDDING_KEY: {}, ENCODER_KEY: {}} megatron_dict = { - MODEL_KEY: {LANGUGAGE_MODEL_KEY: language_model_dict}, + MODEL_KEY: { + LANGUGAGE_MODEL_KEY: language_model_dict + }, CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE } return megatron_dict @@ -78,7 +95,11 @@ def _renest_sd(sd): return new_sd -def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False): +def _create_rank_checkpoint(ds_checkpoint, + checkpoint_path, + tp_index, + pp_index, + for_release=False): meg_encoder_sd = OrderedDict() 
meg_embedding_sd = OrderedDict() meg_embedding_for_head_sd = OrderedDict() @@ -92,7 +113,7 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, if pp_index == 0: meg_embedding_sd.update(nested_embedding_sd) - if pp_index == ds_checkpoint.pp_degree -1: + if pp_index == ds_checkpoint.pp_degree - 1: for key, value in embedding_sd.items(): if key.startswith(WORD_EMBEDDINGS_KEY): fields = key.split('.') @@ -101,7 +122,10 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, meg_embedding_for_head_sd[new_key] = value final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index) - new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()} + new_final_norm_sd = { + f'{FINAL_LAYER_NORM_KEY}.{key}': value + for key, value in final_norm_sd.items() + } meg_encoder_sd.update(new_final_norm_sd) checkpoint_sd = _create_megatron_dict() @@ -109,15 +133,19 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, iteration = ds_checkpoint.get_iteration() checkpoint_sd[ITERATION_KEY] = iteration if pp_index == 0: - checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd + checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ + EMBEDDING_KEY] = meg_embedding_sd checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd - if pp_index == ds_checkpoint.pp_degree -1: - checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd + if pp_index == ds_checkpoint.pp_degree - 1: + checkpoint_sd[MODEL_KEY][ + WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args() # Adjust specific fields - checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree - checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree + checkpoint_sd[ + ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree + checkpoint_sd[ + ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree if for_release: checkpoint_sd[ARGS_KEY].consumed_train_samples = 0 checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0 @@ -131,20 +159,27 @@ def _create_latest_file(base_folder, iteration): with open(file_path, 'w') as f: f.write(str(iteration)) + def main(): print(f'Convert DeepSpeed Checkpoint to Megatron Checkpoint') args = parse_arguments() - print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}') + print( + f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}' + ) - ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp) + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, + args.target_pp) iteration = ds_checkpoint.get_iteration() _create_latest_file(args.output_folder, iteration) - checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree) + checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, + ds_checkpoint.tp_degree, + ds_checkpoint.pp_degree) for i in range(0, ds_checkpoint.tp_degree): for j in range(0, ds_checkpoint.pp_degree): sd = _create_rank_checkpoint(ds_checkpoint, i, j, args.for_release) _save_checkpoint(checkpoint_paths[i][j], sd) + if __name__ == "__main__": main() diff --git a/tools/convert_checkpoint/deepspeed_to_transformers.py b/tools/convert_checkpoint/deepspeed_to_transformers.py index 667695026..015f63a94 100755 --- 
a/tools/convert_checkpoint/deepspeed_to_transformers.py +++ b/tools/convert_checkpoint/deepspeed_to_transformers.py @@ -3,31 +3,37 @@ import os import torch import json - -from deepspeed_checkpoint import DeepSpeedCheckpoint +import sys +from pathlib import Path + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + +from deepspeed.checkpoint import DeepSpeedCheckpoint from deepspeed_to_megatron import _create_rank_checkpoint, parse_arguments # the import was tested to work with this version # https://github.com/huggingface/transformers/commit/0af901e83 if it diverges we may consider # copying that version here instead -from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import ( - convert_megatron_checkpoint, -) -from transformers import GPT2Config, AutoTokenizer +from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import convert_megatron_checkpoint +from transformers import GPT2Config def main(): + # this first part comes mainly from deepspeed_to_megatron.main args = parse_arguments() print( - f"Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}" + f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}' ) - ds_checkpoint = DeepSpeedCheckpoint( - args.input_folder, args.target_tp, args.target_pp - ) - ds_args = ds_checkpoint.get_args() - input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, args.for_release) + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, + args.target_pp) + iteration = ds_checkpoint.get_iteration() + input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, + args.for_release) # the 2nd part comes from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint.main # Spell out all parameters in case the defaults change. @@ -59,13 +65,14 @@ def main(): # Convert. print("Converting to HF Checkpoint") - output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config) + output_state_dict = convert_megatron_checkpoint(args, input_state_dict, + config) basename = args.output_folder os.makedirs(basename, exist_ok=True) # Print the structure of converted state dict. - # if args.print_checkpoint_structure: + #if args.print_checkpoint_structure: # recursive_print(None, output_state_dict) # Store the config to file. 
@@ -73,20 +80,6 @@ def main(): output_config = config.to_dict() output_config["architectures"] = ["GPT2LMHeadModel"] output_config["model_type"] = "gpt2" - - # Add tokenizer class info to config.json - # see https://github.com/huggingface/transformers/issues/13906) - tokenizer_type = ds_args.tokenizer_type - if tokenizer_type == "GPT2BPETokenizer": - tokenizer_model_name = "gpt2" - elif tokenizer_type == "PretrainedFromHF": - tokenizer_model_name = ds_args.tokenizer_name_or_path - else: - raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name) - tokenizer_class = type(tokenizer).__name__ - output_config["tokenizer_class"] = tokenizer_class - print(f'Saving config to "{output_config_file}"') with open(output_config_file, "w") as f: json.dump(output_config, f) @@ -96,9 +89,7 @@ def main(): print(f'Saving checkpoint to "{output_checkpoint_file}"') torch.save(output_state_dict, output_checkpoint_file) - # Save tokenizer based on args - print(f"Adding {tokenizer_class} tokenizer files") - tokenizer.save_pretrained(basename) + print("Now add tokenizer files and upload to the hub") if __name__ == "__main__": diff --git a/tools/convert_checkpoint/ds_to_universal.py b/tools/convert_checkpoint/ds_to_universal.py new file mode 100755 index 000000000..9a5dd1154 --- /dev/null +++ b/tools/convert_checkpoint/ds_to_universal.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python + +from collections import OrderedDict +from copy import deepcopy +from email.policy import default +from functools import partial +from pathlib import Path +from pprint import pprint +import argparse +import glob +import itertools +import logging +import multiprocessing +import os +import re +import shutil +import sys +import torch +import tqdm + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + + +from deepspeed.checkpoint import DeepSpeedCheckpoint + +MODEL_KEY = 'model' +ARGS_KEY = 'args' +LANGUGAGE_MODEL_KEY = 'language_model' +EMBEDDING_KEY = 'embedding' +ENCODER_KEY = 'encoder' +WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head' +WORD_EMBEDDINGS_KEY = 'word_embeddings' +FINAL_LAYER_NORM_KEY = 'final_layernorm' +CHECKPOINT_VERSION_KEY = 'checkpoint_version' +CHECKPOINT_VERSION_VALUE = 3.0 +ITERATION_KEY = 'iteration' + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_folder', + type=str, + help='Input DeepSpeed Checkpoint folder') + parser.add_argument('--output_folder', + type=str, + help='Output Megatron checkpoint folder') + parser.add_argument('--target_tp', + default=1, + type=int, + help='Target TP degree') + parser.add_argument('--target_pp', + default=1, + type=int, + help='Target PP degree') + parser.add_argument('--num_extract_workers', + default=4, + type=int, + help='How many parallel processes to extract zero shards') + parser.add_argument('--num_merge_workers', + default=2, + type=int, + help='How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers))') + parser.add_argument( + '--for_release', + action='store_true', + help='Convert for release purpose, reset some (progress) counters.') + args = parser.parse_args() + print(f'args = {args}') + return args + + +def _convert_ds_transformer_state(sd_list): + new_sd = OrderedDict() + for i, sd in enumerate(sd_list): + for key, value in sd.items(): + new_key = 
f'layers.{i}.{key}' + new_sd[new_key] = value + + return new_sd + + +def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): + path_list = [] + iter_folder = f'iter_{iteration:07d}' + for i in range(0, tp_degree): + path_list.append([]) + for j in range(0, pp_degree): + rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' + ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') + path_list[i].append( + os.path.join(base_folder, iter_folder, ckpt_path)) + + return path_list + + +def _create_megatron_dict(): + language_model_dict = {EMBEDDING_KEY: {}, ENCODER_KEY: {}} + megatron_dict = { + MODEL_KEY: { + LANGUGAGE_MODEL_KEY: language_model_dict + }, + CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE + } + return megatron_dict + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + + + +def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D): + pp_index, tp_index, dp_index = indices_3D + sd = ds_checkpoint.get_zero_checkpoint_state( + pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index) + + #pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}") + + optim_sd = sd["optimizer_state_dict"] + param_slice_mappings = optim_sd["param_slice_mappings"] + + # dict + state_groups = optim_sd["base_optimizer_state"]["state"] + # list + fp32_groups = optim_sd["single_partition_of_fp32_groups"] + param_groups_cnt = len(state_groups) + + for param_group_id in range(param_groups_cnt): + + flat_state = dict( + exp_avg=state_groups[param_group_id]["exp_avg"], + exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"], + fp32=fp32_groups[param_group_id], + ) + + for name,fragment_mapping in param_slice_mappings[param_group_id].items(): + if "word_embeddings.weight" in name and pp_index > 0: + # Skip tied weights that are replicated in first and last pp stages + continue + + #print(f"{param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}") + for state_key in flat_state.keys(): + dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name, fragment_mapping.start, fragment_mapping.numel) + + + + +cnt = 0 +def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): + + global cnt # temp hack + + param_base_path = os.path.join(dir, param_name, str(tp_index)) + os.makedirs(param_base_path, exist_ok=True) + + cnt += 1 + counter = f"{dp_index:0>2d}" + + path = os.path.join(param_base_path, f"{state_name}.{counter}") + + #print(f"{param_name}: {offset}: {numel} => {path}") + + t = state_flat_tensor.narrow(0, offset, numel) + _save_checkpoint(path, t) + + +def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): + slices = [] + for tp_index in range(tp_degree): + prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") + paths = sorted(list(glob.glob(f"{prefix_path}.0*"))) + #print(paths) + shards = [torch.load(p) for p in paths] + slice = torch.cat(shards, dim=0).reshape(slice_shape) + slices.append(slice) + + return slices + + +ORIGINAL_VOCAB_SIZE = 'original_vocab_size' +def _strip_vocab_padding(ds_checkpoint, padded_vocab_tensor): + checkpoint_info = ds_checkpoint.get_checkpoint_info() + padding_tensor = padded_vocab_tensor.narrow(0, checkpoint_info[ORIGINAL_VOCAB_SIZE], padded_vocab_tensor.shape[0]-checkpoint_info[ORIGINAL_VOCAB_SIZE]) + #print(f'{padded_vocab_tensor[checkpoint_info[ORIGINAL_VOCAB_SIZE]-3:,:]=}') + return 
padded_vocab_tensor.narrow(0, 0, checkpoint_info[ORIGINAL_VOCAB_SIZE]) + + +WEIGHTS_TO_AVERAGE_PATTERNS = [ + r"tied_modules.embed.word_embeddings.norm.weight", + r"tied_modules.embed.word_embeddings.norm.bias", + r"\d+.input_layernorm.weight", + r"\d+.input_layernorm.bias", + r"\d+.post_attention_layernorm.weight", + r"\d+.post_attention_layernorm.bias", + r"\d+.self_attention.dense.bias", + r"\d+.mlp.dense_4h_to_h.bias", + r"\d+.weight", + r"\d+.bias", +] + +WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [ + "dense_4h_to_h.weight", + "self_attention.dense.weight", +] + + +def _get_vocab_divisibility_padding_tensor(ds_checkpoint, padded_vocab_tensor): + checkpoint_info = ds_checkpoint.get_checkpoint_info() + if padded_vocab_tensor.shape[0] > checkpoint_info[ORIGINAL_VOCAB_SIZE]: + return padded_vocab_tensor[-1] + else: + return torch.zeros(padded_vocab_tensor.shape[1]) + +def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape): + name, shape = name_and_shape + slice_base_path = os.path.join(slice_dir, name) + param_base_path = os.path.join(dir, name) + + for state in ("fp32", "exp_avg", "exp_avg_sq"): + slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape) + final_path = os.path.join(param_base_path, f"{state}.pt") + + #print(f"Expected shape: {shape}") + #print(f"Fragment sizes:", list(frag.shape for frag in slices)) + ckpt_dict = {} + if any(re.match(pattern, name) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): + param = sum(slices) / len(slices) + else: + cat_dim = 1 if any(text in name for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0 + #print(f"CAT DIM: {cat_dim}") + param = torch.cat(slices, dim=cat_dim) + ckpt_dict['cat_dim'] = cat_dim + + if "word_embeddings.weight" in name: + #print(f"Before {param.shape=}") + # strip padding + #param = _strip_vocab_padding(ds_checkpoint, param) + ckpt_dict['vocab_divisibility_padding_tensor'] = _get_vocab_divisibility_padding_tensor(ds_checkpoint, param) + #print(f"After {param.shape=}") + + #print(f"Final shape: {param.shape}") + ckpt_dict['param'] = param + _save_checkpoint(final_path, ckpt_dict) + + + + + + +def _get_chunks(l, n): + for i in range(0, len(l), n): + yield l[i:i + n] + + +def _do_parallel_work(do_work, work_chunks, num_workers): + pool = multiprocessing.Pool(num_workers) + for batch in tqdm.tqdm(work_chunks): + pool.map(do_work, batch) + pool.close() + pool.join() + +def _extract_zero_shard_files(args, ds_checkpoint, slice_shapes, temp_dir): + _3d_range_list = list(itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree))) + #pprint(_3d_range_list) + work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers)) + #pprint(work_chunks) + + do_work = partial(extract_zero_shards, temp_dir, slice_shapes, ds_checkpoint) + _do_parallel_work(do_work, work_chunks, args.num_extract_workers) + + + +def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): + work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers)) + #pprint(work_chunks) + zero_output_folder = os.path.join(args.output_folder, "zero") + do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree) + _do_parallel_work(do_work, work_chunks, args.num_merge_workers) + + + +def main(): + print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint') + + args = parse_arguments() + print( + f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}' + ) 
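+    # The conversion is two-phase: (1) extract_zero_shards() walks every
+    # (pp, tp, dp) ZeRO shard and dumps each parameter's fp32 / exp_avg /
+    # exp_avg_sq fragments into a temp dir; (2) merge_tp_slices() stitches the
+    # fragments back together per parameter -- concatenating along the TP axis
+    # or averaging replicated weights -- and writes the result under
+    # <output_folder>/zero/<param_name>/.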
+ + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)#, 1, 2) # args.target_tp, args.target_pp) + + iteration = ds_checkpoint.get_iteration() + #_create_latest_file(args.output_folder, iteration) + checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, + ds_checkpoint.tp_degree, + ds_checkpoint.pp_degree) + + slice_shapes = [] + for mp_rank_file in ds_checkpoint.mp_rank_files: + mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu')) + slice_shapes += mp_sd["param_shapes"] + + # fix back to normal flat dict, merge duplicates for tp>1 + slice_shapes = dict((k,v) for d in slice_shapes for k,v in d.items() ) + temp_dir = os.path.join(args.output_folder, 'tmp') + + print('*** 1. Extracting ZeRO fragments') + _extract_zero_shard_files(args, ds_checkpoint, slice_shapes, temp_dir) + + print('*** 2. Merging slices') + _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) + + shutil.rmtree(temp_dir, ignore_errors=True) + + # Copy mp* files into output folder + for f in glob.glob(os.path.join(args.input_folder, 'mp*')): + shutil.copy2(f, args.output_folder) + + # Update latest to output folder + checkpoint_root_folder, step_folder = os.path.split(args.output_folder) + latest_file = os.path.join(checkpoint_root_folder, 'latest_universal') + with open(latest_file, "w") as f: + f.write(step_folder) + + print('*** Done!') + + +if __name__ == "__main__": + main() diff --git a/tools/convert_checkpoint/inspect_checkpoint.py b/tools/convert_checkpoint/inspect_checkpoint.py index 5ee955bb4..bee217f5c 100644 --- a/tools/convert_checkpoint/inspect_checkpoint.py +++ b/tools/convert_checkpoint/inspect_checkpoint.py @@ -1,13 +1,19 @@ -import torch import sys +import torch import os from collections import OrderedDict +from pathlib import Path + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) def dump_data(datum, name_list=[]): if type(datum) in (dict, OrderedDict): for k, v in datum.items(): - dump_data(v, name_list+[str(k)]) + dump_data(v, name_list + [str(k)]) elif type(datum) in (list, tuple): for v in datum: dump_data(v, name_list) @@ -15,10 +21,11 @@ def dump_data(datum, name_list=[]): prefix = '.'.join(name_list) print(f'[tensor] {prefix} = {datum.shape}') else: - #pass + #pass prefix = '.'.join(name_list) print(f'[other] {prefix} = {datum}') + def main(): if len(sys.argv) < 2: print(f'Usage: {sys.argv[0]} ') @@ -30,7 +37,7 @@ def main(): exit(1) print(f'loading checkpoint file: {ckpt_file}') - sd = torch.load(ckpt_file) + sd = torch.load(ckpt_file, map_location=torch.device('cpu')) dump_data(sd) quit() diff --git a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py index 3125f7d9a..09fa60991 100644 --- a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py +++ b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py @@ -1,19 +1,39 @@ +import sys +from pathlib import Path + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + import argparse -from deepspeed_checkpoint import DeepSpeedCheckpoint + +from deepspeed.checkpoint import DeepSpeedCheckpoint + def list_files(file_list, tag): print(f'Listing files: {tag}') for i, file in enumerate(file_list): print(f'{i+1}: {file}') + def parse_arguments(): parser = argparse.ArgumentParser() - 
parser.add_argument('--folder', default=None, type=str, help='DeepSpeed Checkpoint folder') - parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree') - parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree') + parser.add_argument('--folder', + default=None, + type=str, + help='DeepSpeed Checkpoint folder') + parser.add_argument('--target_tp', + default=None, + type=int, + help='Target TP degree') + parser.add_argument('--target_pp', + default=None, + type=int, + help='Target PP degree') args = parser.parse_args() print(f'args = {args}') - return args + return args def show_input_files(ds_checkpoint): @@ -22,38 +42,52 @@ def show_input_files(ds_checkpoint): list_files(ds_checkpoint.layer_files, 'layer') list_files(ds_checkpoint.mp_rank_files, 'mp rank') + def show_simple_state(ds_checkpoint): print(f'layer keys = {ds_checkpoint.layer_keys}') print(f'layer count = {ds_checkpoint.layer_count}') - print(f'tp_degree_count = {ds_checkpoint.tp_degree}') - print(f'pp_degree_count = {ds_checkpoint.pp_degree}') + print( + f'tp_degree_count = {ds_checkpoint.original_tp_degree} ------> {ds_checkpoint.tp_degree}' + ) + print( + f'pp_degree_count = {ds_checkpoint.original_pp_degree} ------> {ds_checkpoint.pp_degree}' + ) print(f'dp_degree_count = {ds_checkpoint.dp_degree}') + ds_checkpoint.old_2d_map.print_data('old 2d map ==>') + ds_checkpoint.new_2d_map.print_data('new 2d map ==>') + def show_mappings(ds_checkpoint): ds_checkpoint.show_pp_tranformer_map() ds_checkpoint.show_transformer_file_map() ds_checkpoint.show_tp_embedding_map() ds_checkpoint.show_tp_final_norm_map() + ds_checkpoint.show_2d_mapping() + def show_state_summary(tag, sd): - summary = {k:v.shape for k,v in sd.items()} + summary = {k: v.shape for k, v in sd.items()} print(f'{tag} = {summary}') + def show_embedding_states(ds_checkpoint): for i in range(0, ds_checkpoint.tp_degree): sd = ds_checkpoint.get_embedding_state(i) show_state_summary(f'embedding[{i}]', sd) + def show_final_norm_states(ds_checkpoint): for i in range(0, ds_checkpoint.tp_degree): sd = ds_checkpoint.get_final_norm_state(i) show_state_summary(f'final_norm[{i}]', sd) + def show_transformer_states(ds_checkpoint): for i in range(0, ds_checkpoint.tp_degree): for j in range(0, ds_checkpoint.pp_degree): - state_list = ds_checkpoint.get_transformer_state(tp_index=i, pp_index=j) + state_list = ds_checkpoint.get_transformer_state(tp_index=i, + pp_index=j) print(f'tp_pp_rank[{i},{j}] = ') for k, sd in enumerate(state_list): show_state_summary(f' block[{k}]', sd) @@ -64,9 +98,11 @@ def main(): print(f'Inspecting DeepSpeed Checkpoint') args = parse_arguments() - ds_checkpoint = DeepSpeedCheckpoint(args.folder, args.target_tp, args.target_pp) + ds_checkpoint = DeepSpeedCheckpoint(args.folder, args.target_tp, + args.target_pp) ds_checkpoint.validate_files() - + + show_simple_state(ds_checkpoint) show_input_files(ds_checkpoint) show_simple_state(ds_checkpoint) show_mappings(ds_checkpoint) @@ -76,5 +112,6 @@ def main(): checkpoint_args = ds_checkpoint.get_args() print(f'checkpoint args = {checkpoint_args}') + if __name__ == "__main__": main()
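
Reviewer note: below is a minimal, illustrative sketch (not part of this patch) of how the per-parameter files written by ds_to_universal.py could be spot-checked after a conversion run. The helper name and example paths are assumptions; the on-disk layout (<output_folder>/zero/<param_name>/{fp32,exp_avg,exp_avg_sq}.pt) and the dict keys ('param', optional 'cat_dim' and 'vocab_divisibility_padding_tensor') follow merge_tp_slices() above.

# spot_check_universal.py -- illustrative helper, not included in this diff
import os
import torch

def peek_merged_slice(output_folder, param_name, state="fp32"):
    """Load one merged slice written by merge_tp_slices() and print a summary."""
    path = os.path.join(output_folder, "zero", param_name, f"{state}.pt")
    ckpt_dict = torch.load(path, map_location=torch.device('cpu'))
    param = ckpt_dict["param"]
    print(f"{param_name}/{state}: shape={tuple(param.shape)}, "
          f"cat_dim={ckpt_dict.get('cat_dim')}, "
          f"has_vocab_padding_tensor={'vocab_divisibility_padding_tensor' in ckpt_dict}")
    return ckpt_dict

# Example (folder and parameter names are hypothetical):
# peek_merged_slice("checkpoints/global_step1000_universal",
#                   "tied_modules.embed.word_embeddings.weight")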