diff --git a/MFU_README.md b/MFU_README.md
new file mode 100644
index 0000000000..8491ce98e9
--- /dev/null
+++ b/MFU_README.md
@@ -0,0 +1,153 @@
+# MFU (Model FLOPs Utilization) Support for grpo_fast
+
+This document describes the MFU (Model FLOPs Utilization) functionality that has been added to the `grpo_fast.py` training script.
+
+## Overview
+
+MFU measures how much of the hardware's theoretical peak throughput a model is actually using. It is calculated as:
+
+```
+MFU = (Actual FLOPs per second) / (Theoretical Peak FLOPs per second) × 100%
+```
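+
+For intuition, a worked example with assumed numbers: a 7B-parameter model generating 1,000 tokens/s on a single A100 (312 TFLOPS peak BF16), using the ~2 × parameter-count rule of thumb for forward-pass FLOPs per token:
+
+```
+FLOPs per token         ≈ 2 × 7e9 = 1.4e10
+Actual FLOPs per second = 1.4e10 × 1,000 = 1.4e13
+MFU                     = 1.4e13 / 312e12 × 100% ≈ 4.5%
+```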
+
+## Features Added
+
+### 1. FLOPs Calculation
+- **Function**: `calculate_model_flops_per_token(model, tokenizer)`
+- **Purpose**: Calculates the FLOPs for a single-token forward pass using fvcore's flop counter
+- **Dependencies**: Requires the `fvcore` library for FLOPs counting
+
+### 2. MFU Calculation
+- **Function**: `calculate_mfu(flops_per_token, tokens_per_second, model_name)`
+- **Purpose**: Calculates the MFU percentage from FLOPs per token and tokens per second
+- **Model Size Detection**: Detects the model size from the model name (7b, 8b, 13b, 32b, 70b)
+- **Theoretical Peak FLOPs**: Uses the A100 theoretical peak of 312 TFLOPS
+
+### 3. Token Tracking
+- **Location**: `vllm_generate_thread` function
+- **Purpose**: Tracks tokens generated by the actors for MFU calculation
+- **Storage**: Token counts are appended to a shared list for periodic MFU calculation
+
+### 4. MFU Logging
+- **Location**: `one_training_step` function
+- **Metrics Logged**:
+  - `mfu/tokens_per_second`: Tokens generated per second by the actors
+  - `mfu/flops_per_token`: FLOPs per token for the model
+  - `mfu/mfu_percentage`: Calculated MFU percentage
+  - `mfu/total_tokens_generated`: Total tokens generated in the period
+  - `mfu/elapsed_time`: Time elapsed in the calculation period
+
+## Configuration
+
+### New Arguments Added to the Args Class
+
+```python
+# MFU (Model FLOPs Utilization) settings
+enable_mfu_tracking: bool = True
+"""Whether to enable MFU tracking"""
+mfu_calculation_freq: int = 10
+"""How often to calculate MFU (in training steps)"""
+```
+
+### Usage
+
+1. **Enable MFU tracking** (default: True):
+   ```bash
+   python open_instruct/grpo_fast.py --enable_mfu_tracking True
+   ```
+
+2. **Disable MFU tracking**:
+   ```bash
+   python open_instruct/grpo_fast.py --enable_mfu_tracking False
+   ```
+
+3. **Set the MFU calculation frequency**:
+   ```bash
+   python open_instruct/grpo_fast.py --mfu_calculation_freq 5
+   ```
+
+## Dependencies
+
+### Required Package
+- `fvcore>=0.1.5.post20221221`: For FLOPs counting functionality
+
+### Installation
+```bash
+pip install "fvcore>=0.1.5.post20221221"
+```
+
+## Implementation Details
+
+### 1. FLOPs Calculation Process
+1. Creates a dummy input tensor with shape `(1, 1)` (a single token)
+2. Builds a `FlopCountAnalysis` from fvcore over the model and the dummy input
+3. Runs the analysis, which traces one forward pass
+4. Reads the total FLOPs from the analysis object (a standalone sketch follows)
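+
+As a rough standalone illustration of this process (not the exact training-time code path; `gpt2` is only a placeholder model):
+
+```python
+import torch
+from fvcore.nn import FlopCountAnalysis
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+# Single-token dummy input, shape (1, 1)
+dummy_input = torch.randint(0, tokenizer.vocab_size, (1, 1))
+
+with torch.no_grad():
+    flops = FlopCountAnalysis(model, (dummy_input,))
+    # Note: fvcore counts a fused multiply-add as one flop, so this is
+    # roughly half of the "raw" FLOP count
+    print(f"FLOPs for a one-token forward pass: {flops.total():,}")
+```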
+
+### 2. Token Tracking Process
+1. In `vllm_generate_thread`, after each generation:
+   - Counts the total tokens generated across all responses
+   - Appends the count to the shared `mfu_tokens_generated` list
+2. In `one_training_step`, every `mfu_calculation_freq` steps:
+   - Calculates tokens per second from the accumulated counts
+   - Computes MFU using the precomputed FLOPs per token
+   - Logs the metrics to wandb
+   - Clears the token list and resets the timer for the next period
+
+### 3. Model Size Detection
+The system detects the model size from the model name:
+- Searches for the patterns "70b", "32b", "13b", "8b", "7b" (longest first)
+- Defaults to "7b" if no size is detected
+- Looks up the theoretical peak FLOPs for the detected size (currently all sizes map to the same single-A100 peak of 312 TFLOPS, so the table is a placeholder for per-setup values)
+
+## Example Output
+
+When MFU tracking is enabled, you'll see logs like:
+```
+Calculating FLOPs per token for MFU tracking...
+FLOPs per token: 14000000000.0
+```
+
+And wandb metrics like (illustrative, mutually consistent values):
+```
+mfu/tokens_per_second: 150.5
+mfu/flops_per_token: 14000000000.0
+mfu/mfu_percentage: 0.68
+mfu/total_tokens_generated: 1505
+mfu/elapsed_time: 10.0
+```
+
+## Troubleshooting
+
+### Common Issues
+
+1. **fvcore not available**:
+   ```
+   fvcore not available, MFU calculation will be disabled
+   ```
+   - **Solution**: Install fvcore: `pip install "fvcore>=0.1.5.post20221221"`
+
+2. **Model size not detected**:
+   ```
+   Could not determine model size from model_name, using 7b as default
+   ```
+   - **Solution**: Ensure the model name contains size information (e.g., "llama-7b", "qwen-13b")
+
+3. **Zero FLOPs per token**:
+   - **Cause**: Model not properly loaded, or a device mismatch
+   - **Solution**: Check model loading and device placement
+
+## Testing
+
+Run the test script to verify the MFU functionality:
+```bash
+python test_mfu.py
+```
+
+## Notes
+
+- MFU is measured using the first model in the policy group
+- FLOPs per token is computed once at the start of training
+- Token tracking uses a shared list; `list.append` is atomic under CPython's GIL, though a count appended between the read and the clear in `one_training_step` can occasionally be lost
+- MFU metrics are logged to wandb with the `mfu/` prefix
+- Missing dependencies are handled gracefully by disabling MFU tracking
+- The peak-FLOPs table is per GPU; for multi-GPU generation the reported MFU is inflated (it can even exceed 100%) unless the denominator is scaled by the number of GPUs
\ No newline at end of file
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 34582ed467..07799e6926 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -75,6 +75,15 @@
 from transformers.integrations import HfDeepSpeedConfig
 from vllm import SamplingParams
 
+# Import FLOPs calculation utilities
+try:
+    from fvcore.nn import FlopCountAnalysis
+    FLOPS_AVAILABLE = True
+except ImportError:
+    FLOPS_AVAILABLE = False
+    logging.getLogger(__name__).warning("fvcore not available, MFU calculation will be disabled")
+
 from open_instruct.dataset_transformation import (
     DATASET_SOURCE_KEY,
     GROUND_TRUTHS_KEY,
@@ -134,6 +143,81 @@
 INVALID_LOGPROB = 1.0
 
 
+def calculate_model_flops_per_token(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> float:
+    """
+    Calculate the forward-pass FLOPs for a single token.
+
+    Args:
+        model: The model to calculate FLOPs for
+        tokenizer: The tokenizer for the model
+
+    Returns:
+        The number of FLOPs per token (0.0 if fvcore is unavailable)
+    """
+    if not FLOPS_AVAILABLE:
+        logger.warning("fvcore not available, cannot calculate FLOPs per token")
+        return 0.0
+
+    # Create a dummy input for a single token
+    dummy_input = torch.randint(0, tokenizer.vocab_size, (1, 1), device=model.device)
+    attention_mask = torch.ones_like(dummy_input)
+
+    # Count FLOPs for one forward pass; note that fvcore counts a fused
+    # multiply-add as a single flop
+    with torch.no_grad():
+        flop_analysis = FlopCountAnalysis(model, (dummy_input, attention_mask))
+        total_flops = float(flop_analysis.total())
+
+    logger.info(f"Calculated {total_flops} FLOPs per token")
+    return total_flops
+
+
+def calculate_mfu(flops_per_token: float, tokens_per_second: float, model_name: str) -> float:
+    """
+    Calculate Model FLOPs Utilization (MFU).
+
+    Args:
+        flops_per_token: Forward-pass FLOPs per token for the model
+        tokens_per_second: Tokens per second generated by the actors
+        model_name: Name of the model, used to look up the theoretical peak FLOPs
+
+    Returns:
+        MFU as a percentage (0-100)
+    """
+    # Theoretical peak FLOPs for different model sizes (in FLOPs/second).
+    # Every entry is currently the single-A100 BF16 peak of 312 TFLOPS; the
+    # table exists so per-size (or per-cluster) peaks can be plugged in later.
+    theoretical_peak_flops = {
+        "7b": 312e12,  # 312 TFLOPS for A100
+        "8b": 312e12,
+        "13b": 312e12,
+        "32b": 312e12,
+        "70b": 312e12,
+    }
+
+    # Try to extract the model size from the model name
+    model_size = None
+    for size in ["70b", "32b", "13b", "8b", "7b"]:
+        if size in model_name.lower():
+            model_size = size
+            break
+
+    if model_size is None:
+        logger.warning(f"Could not determine model size from {model_name}, using 7b as default")
+        model_size = "7b"
+
+    peak_flops = theoretical_peak_flops[model_size]
+
+    # Actual FLOPs per second, aggregated over the actors
+    actual_flops_per_second = flops_per_token * tokens_per_second
+
+    # Calculate MFU
+    mfu = (actual_flops_per_second / peak_flops) * 100
+
+    return mfu
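+
+# Example with assumed numbers: a 7B model needs roughly 2 * 7e9 = 1.4e10
+# forward FLOPs per token, so at 1,000 tokens/s on a single A100:
+#   calculate_mfu(1.4e10, 1000.0, "llama-7b") -> ~4.49 (percent)
+# peak_flops above is per GPU; scale it by the GPU count for multi-GPU runs,
+# otherwise the reported MFU is inflated (and can exceed 100%).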
+
+
 @dataclass
 class Args:
     # Dataset
@@ -387,6 +471,12 @@ class Args:
     # code-tool specific settings
     code_tool_api_endpoint: Optional[str] = None
 
+    # MFU (Model FLOPs Utilization) settings
+    enable_mfu_tracking: bool = True
+    """Whether to enable MFU tracking"""
+    mfu_calculation_freq: int = 10
+    """How often to calculate MFU (in training steps)"""
+
     def __post_init__(self):
         assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
         if self.num_samples_per_prompt_rollout == 1:
@@ -1025,6 +1115,10 @@ def save_model(self, output_dir: str) -> None:
         # save tokenizer
         self.tokenizer.save_pretrained(output_dir)
 
+    def calculate_flops_per_token(self):
+        """Compute FLOPs per token on this actor; only the float is returned,
+        so the (possibly sharded) model is never serialized to the driver."""
+        return calculate_model_flops_per_token(self.policy, self.tokenizer)
+
     # we need this because we don't know which node is rank 0 is on
     def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url, training_step):
         args = self.args
@@ -1105,6 +1199,8 @@ def vllm_generate_thread(
     eval_freq: int,
     resume_training_step: int = 1,
     tool_use: bool = False,
+    mfu_tracking_enabled: bool = False,
+    mfu_tokens_generated: Optional[List[int]] = None,
 ):
     def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
         # Split queries between engines
@@ -1161,6 +1257,12 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
         with Timer("🔥 Generation time"):
             response_ids, finish_reasons, masks, info = generate_with_engines(g_queries_list, generation_config)
 
+        # Track tokens generated for MFU calculation (list.append is atomic
+        # under the GIL, so the producer side needs no extra locking)
+        if mfu_tracking_enabled and mfu_tokens_generated is not None:
+            total_tokens_generated = sum(len(response) for response in response_ids)
+            mfu_tokens_generated.append(total_tokens_generated)
+
         inference_results_Q.put((response_ids, finish_reasons, masks, info))
 
         # Evaluate the model
@@ -1687,6 +1789,9 @@ def one_training_step(
     train_dataset,
     writer,
     wandb_url,
+    mfu_tokens_generated=None,
+    flops_per_token=None,
+    mfu_start_time=None,
+    model_name_or_path=None,
 ):
     """Train the model for one step."""
     update_ref_policy_future = []
@@ -1709,6 +1814,33 @@
     )
     average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
 
+    # Calculate MFU if enabled, once every `mfu_calculation_freq` steps
+    mfu_metrics = {}
+    if (
+        args.enable_mfu_tracking
+        and mfu_tokens_generated is not None
+        and flops_per_token
+        and mfu_start_time is not None
+        and training_step % args.mfu_calculation_freq == 0
+    ):
+        # Tokens per second from the actors over the current window
+        total_tokens_from_actors = sum(mfu_tokens_generated)
+        elapsed_time = time.time() - mfu_start_time[0]
+        tokens_per_second = total_tokens_from_actors / elapsed_time if elapsed_time > 0 else 0
+
+        # Calculate MFU against the theoretical peak for this model
+        mfu = calculate_mfu(flops_per_token, tokens_per_second, model_name_or_path or "")
+
+        mfu_metrics = {
+            "mfu/tokens_per_second": tokens_per_second,
+            "mfu/flops_per_token": flops_per_token,
+            "mfu/mfu_percentage": mfu,
+            "mfu/total_tokens_generated": total_tokens_from_actors,
+            "mfu/elapsed_time": elapsed_time,
+        }
+
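+        # Illustrative window math (assumed numbers): 1,505 tokens over a 10 s
+        # window gives tokens_per_second = 150.5; at 1.4e10 forward FLOPs per
+        # token that is ~2.1e12 FLOPs/s, i.e. MFU ≈ 2.1e12 / 312e12 * 100 ≈ 0.68%
+        # of a single A100's peak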
+        # Clear the token list and reset the window start; mfu_start_time is a
+        # one-element list so that this reset is visible to the caller on
+        # subsequent steps
+        mfu_tokens_generated.clear()
+        mfu_start_time[0] = time.time()
+
     metrics = {
         "episode": episode,
         "training_step": training_step,
@@ -1717,6 +1849,7 @@
         "tokens_per_second": num_total_tokens / (time.time() - start_time),
         **data_thread_metrics,
         **average_metrics,
+        **mfu_metrics,
     }
     scalar_metrics = {}
     for key, value in metrics.items():
@@ -1990,6 +2123,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
     evaluation_inference_results_Q = Queue(maxsize=1)
     packed_sequences_Q = Queue(maxsize=args.async_steps)
     queries_prompt_Q = Queue(maxsize=args.async_steps)
+
+    # MFU tracking variables; the token-count list is shared with the
+    # generate thread
+    mfu_tokens_generated = [] if args.enable_mfu_tracking else None
+    flops_per_token = None
+    mfu_start_time = None
 
     eval_prompt_token_ids = None
     eval_ground_truths = None
@@ -2013,6 +2151,8 @@
             args.eval_freq,
             resume_training_step,
             args.tool_use,
+            args.enable_mfu_tracking,
+            mfu_tokens_generated,
         ),
     )
     generate_thread.start()
@@ -2044,6 +2184,19 @@
     num_total_tokens = 0
     start_time = time.time()
 
+    # Calculate FLOPs per token if MFU tracking is enabled
+    if args.enable_mfu_tracking:
+        logger.info("Calculating FLOPs per token for MFU tracking...")
+        # Measure on the first policy actor; only a float crosses the Ray
+        # boundary, so the model weights are never shipped to the driver
+        flops_per_token = ray.get(policy_group.models[0].calculate_flops_per_token.remote())
+        # One-element list so one_training_step can reset the window in place
+        mfu_start_time = [time.time()]
+        logger.info(f"FLOPs per token: {flops_per_token}")
+    else:
+        flops_per_token = None
+        mfu_start_time = None
+
     try:
         for training_step in range(resume_training_step, args.num_training_steps + 1):
             logger.info("-" * 100)
@@ -2083,6 +2236,9 @@
                 train_dataset,
                 writer,
                 wandb_url,
+                mfu_tokens_generated,
+                flops_per_token,
+                mfu_start_time,
+                model_config.model_name_or_path,
             )
 
             maybe_evaluate(
diff --git a/requirements.txt b/requirements.txt
index 6b5ac80e6f..452b094f96 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -869,3 +869,5 @@ yarl==1.20.0
     # via aiohttp
 zipp==3.22.0
     # via importlib-metadata
+fvcore>=0.1.5.post20221221
+    # via open-instruct
diff --git a/test_mfu.py b/test_mfu.py
new file mode 100644
index 0000000000..a42ffbae20
--- /dev/null
+++ b/test_mfu.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+"""
+Test script for the MFU (Model FLOPs Utilization) functionality in grpo_fast.py
+"""
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Import the MFU functions from grpo_fast
+from open_instruct.grpo_fast import calculate_mfu, calculate_model_flops_per_token
+
+
+def test_mfu_functionality():
+    """Test the MFU calculation functionality."""
+    print("Testing MFU functionality...")
+
+    # Use a small model to keep the test cheap
+    model_name = "microsoft/DialoGPT-small"
+    print(f"Loading model: {model_name}")
+
+    try:
+        # Load model and tokenizer
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # Set the model to evaluation mode
+        model.eval()
+
+        # Test the FLOPs calculation
+        print("Calculating FLOPs per token...")
+        flops_per_token = calculate_model_flops_per_token(model, tokenizer)
+        print(f"FLOPs per token: {flops_per_token}")
+
+        # Test the MFU calculation
+        print("Testing MFU calculation...")
+        tokens_per_second = 100.0  # Example tokens per second
+        mfu = calculate_mfu(flops_per_token, tokens_per_second, model_name)
+        print(f"MFU: {mfu:.2f}%")
+
+        print("✅ MFU functionality test passed!")
+    except Exception as e:
+        print(f"❌ MFU functionality test failed: {e}")
+        return False
+
+    return True
+
+
+if __name__ == "__main__":
+    test_mfu_functionality()
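+
+
+# A dependency-light sanity check of calculate_mfu alone; the numbers below
+# are assumptions (2 * 7e9 is the usual forward-pass FLOPs-per-token rule of
+# thumb for a 7B model), not measured values.
+def test_mfu_arithmetic():
+    flops_per_token = 2 * 7e9  # ~ forward FLOPs per token for a 7B model
+    tokens_per_second = 1000.0  # assumed generation throughput
+    mfu = calculate_mfu(flops_per_token, tokens_per_second, "llama-7b")
+    expected = flops_per_token * tokens_per_second / 312e12 * 100
+    assert abs(mfu - expected) < 1e-9
+    print(f"✅ Arithmetic check passed: MFU = {mfu:.2f}%")
+
+
+if __name__ == "__main__":
+    test_mfu_arithmetic()
\ No newline at end of file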