diff --git a/deepspeed/checkpoint/hf_to_universal.py b/deepspeed/checkpoint/hf_to_universal.py
new file mode 100644
index 000000000000..e68be7d8780e
--- /dev/null
+++ b/deepspeed/checkpoint/hf_to_universal.py
@@ -0,0 +1,225 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import os
+import shutil
+import logging
+from concurrent.futures import ProcessPoolExecutor
+from deepspeed.accelerator import get_accelerator
+from tqdm import tqdm
+from typing import List
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Hard-coded constants for parameter patterns
+VOCAB_PARAMETER_PATTERNS = [
+    'word_embeddings',
+    'embed_tokens',
+    'embedding',
+    'wte',  # GPT style embeddings
+    'lm_head'  # Language model head, often tied with embeddings
+]
+
+
+def get_parameter_type(name: str) -> dict:
+    """Determine parameter type and required fields based on name."""
+    param_info = {
+        'cat_dim': 0  # Default concatenation dimension
+    }
+
+    # Check for vocabulary tensors (embeddings, etc.)
+    if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
+        param_info['vocab_tensor'] = True
+
+    # TODO: figure out if we need to check for row-parallel parameters
+    return param_info
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
+    parser.add_argument('--hf_checkpoint_dir',
+                        type=str,
+                        required=True,
+                        help='Path to the HuggingFace checkpoint directory')
+    parser.add_argument('--safe_serialization',
+                        action='store_true',
+                        default=False,
+                        help='Use safetensors for serialization')
+    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
+    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
+    args = parser.parse_args()
+
+    # Create a temporary directory for atomic operations
+    temp_save_dir = args.save_dir + '.tmp'
+
+    def save_parameter(name: str, param: torch.Tensor, save_dir: str):
+        """Save a parameter and its optimizer states in universal format."""
+        # Create parameter directory under zero/
+        param_dir = os.path.join(save_dir, name)
+        os.makedirs(param_dir, exist_ok=True)
+
+        # Get parameter type and required fields
+        param_info = get_parameter_type(name)
+
+        # Save parameter in fp32 with proper dictionary structure
+        param_path = os.path.join(param_dir, "fp32.pt")
+        param_dict = {
+            'param': param.to(torch.float32),  # Main tensor goes in 'param' field
+            **param_info  # Include all determined parameter info
+        }
+        torch.save(param_dict, param_path)
+
+        # Since HuggingFace checkpoints do not have optimizer states,
+        # we initialize them with zeros
+        for state in ("exp_avg", "exp_avg_sq"):
+            state_path = os.path.join(param_dir, f"{state}.pt")
+            state_dict = {
+                'param': torch.zeros_like(param, dtype=torch.float32),
+                **param_info  # Include same parameter info in optimizer states
+            }
+            torch.save(state_dict, state_path)
+
+    def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
+        """Process a single shard file."""
+        try:
+            shard_path = os.path.join(checkpoint_dir, shard_file)
+            logger.info(f"Loading shard from: {shard_path}")
+
+            if safe_serialization:
+                from safetensors.torch import load_file
+                shard_dict = load_file(shard_path)
+            else:
+                shard_dict = torch.load(shard_path, map_location='cpu')
+
+            # Create progress bar for parameters within this shard
+            pbar = tqdm(total=len(shard_dict),
+                        desc=f"Processing {os.path.basename(shard_file)}",
+                        position=1,
+                        leave=False)
+
+            for key, param in shard_dict.items():
+                save_parameter(key, param, save_dir)
+                del param
+                pbar.update(1)
+                pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})
+
+            pbar.close()
+            del shard_dict
+            get_accelerator().empty_cache()
+            logger.info(f"Completed processing shard: {shard_file}")
+
+        except Exception as e:
+            logger.error(f"Error processing shard {shard_file}: {str(e)}")
+            raise
+
+    def get_shard_list(checkpoint_dir):
+        """Get list of shards from index file."""
+        if args.safe_serialization:
+            index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
+        else:
+            index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
+
+        if os.path.exists(index_file):
+            import json
+            with open(index_file, 'r') as f:
+                index = json.load(f)
+            return list(set(index['weight_map'].values()))
+        else:
+            # Handle single file case
+            if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
+                return ["model.safetensors"]
+            elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
+                return ["pytorch_model.bin"]
+            else:
+                raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")
+
+    def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool):
+        """Process a batch of shards in parallel."""
+        with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
+            futures = []
+            for shard_file in shard_files:
+                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
+                futures.append((shard_file, future))
+
+            # Create progress bar for this batch
+            batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True)
+
+            # Wait for all futures to complete
+            for shard_file, future in futures:
+                try:
+                    future.result()  # This will raise any exceptions that occurred
+                    batch_pbar.update(1)
+                    batch_pbar.set_postfix({'last_completed': os.path.basename(shard_file)})
+                except Exception as e:
+                    logger.error(f"Failed processing shard {shard_file}: {str(e)}")
+                    raise
+
+            batch_pbar.close()
+
+    try:
+        # Create zero subdirectory in temp directory
+        temp_zero_dir = os.path.join(temp_save_dir, 'zero')
+        if os.path.exists(temp_zero_dir):
+            logger.info(f"Removing existing temp directory: {temp_zero_dir}")
+            shutil.rmtree(temp_zero_dir)
+
+        shard_files = get_shard_list(args.hf_checkpoint_dir)
+        total_shards = len(shard_files)
+        logger.info(f"Found {total_shards} shards to process")
+        # Process shards in batches equal to the number of workers
+        batch_size = args.num_workers
+        for i in range(0, total_shards, batch_size):
+            batch_shards = shard_files[i:i + batch_size]
+            logger.info(
+                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
+            )
+            process_shard_batch(
+                batch_shards,
+                args.hf_checkpoint_dir,
+                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
+                args.safe_serialization)
+
+            # Clear CUDA cache after each batch to free up memory
+            get_accelerator().empty_cache()
+
+        logger.info("All shard batches processed successfully")
+
+        final_save_dir = os.path.join(args.save_dir, 'zero')
+        if os.path.exists(final_save_dir):
+            shutil.rmtree(final_save_dir)
+
+        # Create the parent directory if it doesn't exist
+        os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
+        # Move the zero directory to its final location
+        os.rename(temp_zero_dir, final_save_dir)
+
+        # Clean up the temporary directory
+        if os.path.exists(temp_save_dir):
+            shutil.rmtree(temp_save_dir)
+
+        # Write identifier file
+        with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
+            f.write("Huggingface checkpoint")
+
+        logger.info(f"Successfully saved checkpoint to {final_save_dir}")
+
+        # Update latest file
+        checkpoint_root_folder = os.path.dirname(args.save_dir)
+        step_folder = os.path.basename(args.save_dir)
+        latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
+        with open(latest_file, 'w') as f:
+            f.write(step_folder)
+
+        logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")
+
+    except Exception as e:
+        logger.error(f"Failed to process checkpoint: {str(e)}")
+        if os.path.exists(temp_save_dir):
+            shutil.rmtree(temp_save_dir)
+        raise
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 726124027131..f3366c06e5a6 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -21,11 +21,14 @@ class ZeROOptimizer(DeepSpeedOptimizer):
     def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-
-        self._load_global_state(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
+            optim_sd = {}
 
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         if self.mpu is None:
@@ -35,8 +38,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
             tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                 else self.mpu.get_tensor_model_parallel_world_size()
 
-        for i, (param_group,
-                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
+        for i, param_group in enumerate(self.optimizer.param_groups):
             # We have an assumption that all params in the same param_group have the same keys
             opt_keys = set()
             steps = []
@@ -58,6 +60,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
 
             map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)
 
+            if ignore_missing_optim_state:
+                continue
+            loaded_param_group = optim_sd['param_groups'][i]
             for key, value in loaded_param_group.items():
                 if key == 'params':
                     continue
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 76b83a716ffe..c8798913d249 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2977,7 +2977,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):
         ckpt_files = glob.glob(ckpt_file_pattern)
         ckpt_files.sort()
 
-        return ckpt_files
+        return ckpt_files, ckpt_file_pattern
 
     def load_checkpoint(self,
                         load_dir,
@@ -3001,7 +3001,7 @@ def load_checkpoint(self,
 
         Returns:
             A tuple of ``load_path`` and ``client_state``.
-            *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
+            *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or loading a HF based UCP.
             *``client_state``: State dictionary used for loading required training states in the client code.
 
         Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
@@ -3040,6 +3040,11 @@ def load_checkpoint(self,
                                                          custom_load_fn=custom_load_fn)
 
         load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
+        if self.load_universal_checkpoint():
+            ucp_ckpt_folder = os.path.join(load_dir, tag)
+            # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist
+            load_zero_checkpoint = os.path.isdir(ucp_ckpt_folder)
+
         if load_zero_checkpoint:
             if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
                 success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
@@ -3080,7 +3085,11 @@ def _load_checkpoint(self,
 
         from deepspeed.runtime.state_dict_factory import SDLoaderFactory
 
-        ckpt_list = self._get_all_ckpt_names(load_dir, tag)
+        ckpt_list, ckpt_file_pattern = self._get_all_ckpt_names(load_dir, tag)
+        if self.load_universal_checkpoint() and len(ckpt_list) == 0:
+            logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}")
+            return None, {}
+
         sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
 
         is_pipe_parallel = isinstance(self.module, PipelineModule)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 8114bdd050dd..f56221d694e1 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2865,11 +2865,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-        self._load_global_state_stage3(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state_stage3(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
 
@@ -2881,14 +2883,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             if key == "fp32":
                 self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
                 self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
-            else:
+            elif not ignore_missing_optim_state:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
 
         if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
             self.optimizer_swapper.purge_state()
 
-        if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
             timer_names = set()
             self._partition_all_parameters()
@@ -2898,9 +2899,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
                 self._release_sub_group(sub_group_id, timer_names)
             self._post_step(timer_names)
 
-        self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
-        for param_group in self.optimizer.param_groups:
-            param_group['params'] = []
+        if not ignore_missing_optim_state:
+            self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
+            for param_group in self.optimizer.param_groups:
+                param_group['params'] = []
 
         for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
             fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
@@ -2924,7 +2926,13 @@ def load_hp_checkpoint_state(self, folder, key):
             local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
+        loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
+        if isinstance(loaded_state, dict):
+            loaded_checkpoint_state = loaded_state['param'].view(-1)
+        elif isinstance(loaded_state, torch.Tensor):
+            loaded_checkpoint_state = loaded_state.view(-1)
+        else:
+            raise ValueError(f"Unknown type {type(loaded_state)} for loaded state")
 
         # Partition the loaded data according to the local rank
         world_size = dist.get_world_size(group=self.dp_process_group)
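Usage note (not part of the patch): a minimal sketch of how the new converter might be invoked, assuming DeepSpeed is installed so that deepspeed/checkpoint/hf_to_universal.py is runnable via `python -m`; all paths below are placeholders.

# Minimal sketch, not a definitive workflow. Assumptions: DeepSpeed is installed,
# the HF checkpoint is already on disk, and save_dir's basename is the step/tag
# folder that will later be passed to engine.load_checkpoint().
import subprocess

subprocess.run(
    [
        "python",
        "-m",
        "deepspeed.checkpoint.hf_to_universal",
        "--hf_checkpoint_dir",
        "/data/models/my-hf-model",  # placeholder: sharded pytorch_model.bin or safetensors checkpoint
        "--save_dir",
        "/data/ckpts/my-run/global_step0",  # placeholder: converter writes <save_dir>/zero/<param>/{fp32,exp_avg,exp_avg_sq}.pt
        "--num_workers",
        "4",
        # add "--safe_serialization" if the HF shards are safetensors
    ],
    check=True,
)

# Training can then resume with universal-checkpoint loading enabled
# (engine.load_universal_checkpoint() returning True): the relaxed loading paths in this
# patch tolerate the missing global optimizer_state.pt and absent *model_states.pt files.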