HF2UCP: Converting a pytorch_model.bin or .safetensors checkpoint to UCP (#7212)

Status: Open

Schwidola0607 wants to merge 15 commits into deepspeedai:master from Schwidola0607:master.
Changes from 14 commits (15 commits total):

- 0e1ea4c  add support for HF2UCP feature (Schwidola0607)
- 727206e  add user guide (Schwidola0607)
- 49588a8  edit user guide (Schwidola0607)
- 7bef517  cleaning up (Schwidola0607)
- 2930f2a  nits (xylian86)
- 9207df9  remove ignore_missing_optim config from zero ds_config (Schwidola0607)
- 8090369  fix to make ucp load more lenient (Schwidola0607)
- 7b8962a  nits (xylian86)
- 2389567  nits (xylian86)
- f34c6df  Merge branch 'master' into master (loadams)
- 2fa0889  formatting and license (Schwidola0607)
- 724a480  Minor comment fix (Schwidola0607)
- 64f5bc3  remove document (Schwidola0607)
- 4dbd67f  Merge branch 'master' into master (tjruwase)
- ea01489  Merge branch 'master' into master (tjruwase)
New file (+225 lines):

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
import shutil
import logging
from concurrent.futures import ProcessPoolExecutor
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hard-coded constants for parameter patterns
VOCAB_PARAMETER_PATTERNS = [
    'word_embeddings',
    'embed_tokens',
    'embedding',
    'wte',  # GPT style embeddings
    'lm_head'  # Language model head, often tied with embeddings
]


def get_parameter_type(name: str) -> dict:
    """Determine parameter type and required fields based on name."""
    param_info = {
        'cat_dim': 0  # Default concatenation dimension
    }

    # Check for vocabulary tensors (embeddings, etc.)
    if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
        param_info['vocab_tensor'] = True

    # TODO: figure out if we need to check for row-parallel parameters
    return param_info


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
    parser.add_argument('--hf_checkpoint_dir',
                        type=str,
                        required=True,
                        help='Path to the HuggingFace checkpoint directory')
    parser.add_argument('--safe_serialization',
                        action='store_true',
                        default=False,
                        help='Use safetensors for serialization')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
    args = parser.parse_args()

    # Create a temporary directory for atomic operations
    temp_save_dir = args.save_dir + '.tmp'

    def save_parameter(name: str, param: torch.Tensor, save_dir: str):
        """Save a parameter and its optimizer states in universal format."""
        # Create parameter directory under zero/
        param_dir = os.path.join(save_dir, name)
        os.makedirs(param_dir, exist_ok=True)

        # Get parameter type and required fields
        param_info = get_parameter_type(name)

        # Save parameter in fp32 with proper dictionary structure
        param_path = os.path.join(param_dir, "fp32.pt")
        param_dict = {
            'param': param.to(torch.float32),  # Main tensor goes in 'param' field
            **param_info  # Include all determined parameter info
        }
        torch.save(param_dict, param_path)

        # Since HuggingFace checkpoints do not have optimizer states,
        # we initialize them with zeros
        for state in ("exp_avg", "exp_avg_sq"):
            state_path = os.path.join(param_dir, f"{state}.pt")
            state_dict = {
                'param': torch.zeros_like(param, dtype=torch.float32),
                **param_info  # Include same parameter info in optimizer states
            }
            torch.save(state_dict, state_path)

    def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
        """Process a single shard file."""
        try:
            shard_path = os.path.join(checkpoint_dir, shard_file)
            logger.info(f"Loading shard from: {shard_path}")

            if safe_serialization:
                from safetensors.torch import load_file
                shard_dict = load_file(shard_path)
            else:
                shard_dict = torch.load(shard_path, map_location='cpu')

            # Create progress bar for parameters within this shard
            pbar = tqdm(total=len(shard_dict),
                        desc=f"Processing {os.path.basename(shard_file)}",
                        position=1,
                        leave=False)

            for key, param in shard_dict.items():
                save_parameter(key, param, save_dir)
                del param
                pbar.update(1)
                pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})

            pbar.close()
            del shard_dict
            get_accelerator().empty_cache()
            logger.info(f"Completed processing shard: {shard_file}")

        except Exception as e:
            logger.error(f"Error processing shard {shard_file}: {str(e)}")
            raise

    def get_shard_list(checkpoint_dir):
        """Get list of shards from index file."""
        if args.safe_serialization:
            index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
        else:
            index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")

        if os.path.exists(index_file):
            import json
            with open(index_file, 'r') as f:
                index = json.load(f)
            return list(set(index['weight_map'].values()))
        else:
            # Handle single file case
            if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
                return ["model.safetensors"]
            elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
                return ["pytorch_model.bin"]
            else:
                raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

    def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool):
        """Process a batch of shards in parallel."""
        with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
            futures = []
            for shard_file in shard_files:
                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
                futures.append((shard_file, future))

            # Create progress bar for this batch
            batch_pbar = tqdm(total=len(futures), desc="Processing shard batch", position=0, leave=True)

            # Wait for all futures to complete
            for shard_file, future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred
                    batch_pbar.update(1)
                    batch_pbar.set_postfix({'last_completed': os.path.basename(shard_file)})
                except Exception as e:
                    logger.error(f"Failed processing shard {shard_file}: {str(e)}")
                    raise

            batch_pbar.close()

    try:
        # Create zero subdirectory in temp directory
        temp_zero_dir = os.path.join(temp_save_dir, 'zero')
        if os.path.exists(temp_zero_dir):
            logger.info(f"Removing existing temp directory: {temp_zero_dir}")
            shutil.rmtree(temp_zero_dir)

        shard_files = get_shard_list(args.hf_checkpoint_dir)
        total_shards = len(shard_files)
        logger.info(f"Found {total_shards} shards to process")
        # Process shards in batches equal to the number of workers
        batch_size = args.num_workers
        for i in range(0, total_shards, batch_size):
            batch_shards = shard_files[i:i + batch_size]
            logger.info(
                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
            )
            process_shard_batch(batch_shards, args.hf_checkpoint_dir, temp_zero_dir, args.safe_serialization)

            # Clear CUDA cache after each batch to free up memory
            get_accelerator().empty_cache()

        logger.info("All shard batches processed successfully")

        final_save_dir = os.path.join(args.save_dir, 'zero')
        if os.path.exists(final_save_dir):
            shutil.rmtree(final_save_dir)

        # Create the parent directory if it doesn't exist
        os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
        # Move the zero directory to its final location
        os.rename(temp_zero_dir, final_save_dir)

        # Clean up the temporary directory
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)

        # Write identifier file
        with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
            f.write("Huggingface checkpoint")

        logger.info(f"Successfully saved checkpoint to {final_save_dir}")

        # Update latest file
        checkpoint_root_folder = os.path.dirname(args.save_dir)
        step_folder = os.path.basename(args.save_dir)
        latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
        with open(latest_file, 'w') as f:
            f.write(step_folder)

        logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")

    except Exception as e:
        logger.error(f"Failed to process checkpoint: {str(e)}")
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)
        raise
```
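The vocabulary check in `get_parameter_type` is a plain case-insensitive substring match against `VOCAB_PARAMETER_PATTERNS`, so it can be exercised standalone without torch or DeepSpeed. A minimal sketch of that logic (the parameter names below are illustrative HF-style names, not taken from the PR):

```python
# Standalone copy of the PR's pattern logic, runnable without torch/DeepSpeed.
VOCAB_PARAMETER_PATTERNS = [
    'word_embeddings',
    'embed_tokens',
    'embedding',
    'wte',      # GPT style embeddings
    'lm_head',  # often tied with embeddings
]


def get_parameter_type(name: str) -> dict:
    """Mirror of the script's helper: tag vocab tensors, default cat_dim=0."""
    param_info = {'cat_dim': 0}
    if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
        param_info['vocab_tensor'] = True
    return param_info


# Embedding-like names get the extra 'vocab_tensor' flag...
print(get_parameter_type('model.embed_tokens.weight'))  # {'cat_dim': 0, 'vocab_tensor': True}
# ...while ordinary projection weights only carry the default cat_dim.
print(get_parameter_type('model.layers.0.self_attn.q_proj.weight'))  # {'cat_dim': 0}
```

Note that `'embedding'` and `'wte'` are broad substrings, so unusually named parameters (e.g. anything containing "embedding") would also be flagged.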
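When an index file is present, `get_shard_list` simply collects the unique shard filenames from the index's `weight_map`. A self-contained sketch of that resolution rule against a fabricated two-shard index (the weight names and shard count are made up for illustration; the real function returns an unordered `list(set(...))`, sorted here for determinism):

```python
import json
import os
import tempfile


def shard_list_from_index(checkpoint_dir: str, index_name: str):
    """Same resolution rule as the PR's get_shard_list: unique weight_map values."""
    with open(os.path.join(checkpoint_dir, index_name)) as f:
        index = json.load(f)
    return sorted(set(index['weight_map'].values()))


with tempfile.TemporaryDirectory() as ckpt_dir:
    # Fabricated index in the HuggingFace sharded-checkpoint layout.
    weight_map = {
        'model.embed_tokens.weight': 'pytorch_model-00001-of-00002.bin',
        'model.layers.0.mlp.up_proj.weight': 'pytorch_model-00001-of-00002.bin',
        'lm_head.weight': 'pytorch_model-00002-of-00002.bin',
    }
    with open(os.path.join(ckpt_dir, 'pytorch_model.bin.index.json'), 'w') as f:
        json.dump({'weight_map': weight_map}, f)

    shards = shard_list_from_index(ckpt_dir, 'pytorch_model.bin.index.json')
    print(shards)  # two unique shards, each listed once
```

Deduplicating via `set` matters because many parameters map to the same shard file; each shard should be loaded and processed exactly once.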
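The on-disk result of a successful run is a `zero/<param_name>/` directory per parameter holding `fp32.pt`, `exp_avg.pt`, and `exp_avg_sq.pt`, plus a `source.txt` marker in `save_dir` and a `latest_universal` file next to it naming the step folder. A stdlib-only sketch of that layout (placeholder empty files stand in for the `torch.save`'d dictionaries; the step-folder name `global_step0` is a hypothetical example):

```python
import os
import tempfile


def sketch_ucp_layout(save_dir: str, param_names):
    """Mimic the directory layout the converter produces. Empty files stand in
    for the torch.save'd {'param': ..., 'cat_dim': ...} dicts in the real script."""
    for name in param_names:
        param_dir = os.path.join(save_dir, 'zero', name)
        os.makedirs(param_dir, exist_ok=True)
        # fp32 weights plus zero-initialized Adam states, as in the PR
        for fname in ('fp32.pt', 'exp_avg.pt', 'exp_avg_sq.pt'):
            open(os.path.join(param_dir, fname), 'w').close()
    # Identifier file written into save_dir itself
    with open(os.path.join(save_dir, 'source.txt'), 'w') as f:
        f.write('Huggingface checkpoint')
    # 'latest_universal' lives beside save_dir and records its basename
    with open(os.path.join(os.path.dirname(save_dir), 'latest_universal'), 'w') as f:
        f.write(os.path.basename(save_dir))


with tempfile.TemporaryDirectory() as root:
    save_dir = os.path.join(root, 'global_step0')  # hypothetical step folder
    os.makedirs(save_dir)
    sketch_ucp_layout(save_dir, ['model.embed_tokens.weight'])

    files = sorted(os.listdir(os.path.join(save_dir, 'zero', 'model.embed_tokens.weight')))
    latest = open(os.path.join(root, 'latest_universal')).read()
    print(files, latest)
```

The `latest_universal` file is what lets a subsequent DeepSpeed run locate the most recent universal checkpoint folder by name.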
  
    