Add waiting time metric to Wandb #779
New file (documentation):

@@ -0,0 +1,109 @@

# Eval Wait Time Metric Implementation

## Overview

This implementation adds a new metric to track the time between eval jobs in the `grpo_fast.py` script. The metric measures how much time the eval job is waiting between runs, which helps identify potential bottlenecks in the evaluation pipeline.

## Implementation Details

### 1. Eval Timing Info Dictionary

Added a dictionary to track the last eval finish time without using global variables:

```python
# Initialize eval timing tracking
eval_timing_info = {"last_eval_finish_time": None}
```
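Because the dictionary is passed by reference into the generation thread, updates made by the main thread become visible to the worker without any module-level state. A minimal, self-contained sketch of that handoff (the names here are illustrative, not taken from `grpo_fast.py`):

```python
import threading
import time

# Shared mutable state, passed explicitly instead of living at module scope.
eval_timing_info = {"last_eval_finish_time": None}

def worker(timing_info: dict) -> None:
    # Reader side: only compute a wait time once a previous eval has finished.
    last = timing_info.get("last_eval_finish_time")
    if last is not None:
        print(f"wait time: {time.time() - last:.2f}s")

# Writer side: the main thread records the finish time before the worker reads it.
eval_timing_info["last_eval_finish_time"] = time.time()
thread = threading.Thread(target=worker, args=(eval_timing_info,))
thread.start()
thread.join()
```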
### 2. Recording Eval Finish Time

Modified the `maybe_evaluate` function to return the eval finish time:

```python
# Record the time when eval job finishes
# This timestamp is used to calculate the wait time for the next eval job
eval_finish_time = time.time()
logger.info("[Main Thread] 📊 Evaluation job finished")

return eval_finish_time
```
### 3. Measuring Wait Time

Modified the `vllm_generate_thread` function to measure and log the wait time when the next eval job is queued:

```python
# Record the time when eval job is queued and measure wait time since last eval finished
current_time = time.time()
if eval_timing_info is not None and eval_timing_info.get("last_eval_finish_time") is not None:
    eval_wait_time = current_time - eval_timing_info["last_eval_finish_time"]
    # Log the eval wait time to wandb if tracking is enabled
    # This metric measures how much time the eval job was waiting between runs
    try:
        import wandb
        if hasattr(wandb, 'run') and wandb.run is not None:
            wandb.log({"eval/wait_time_between_evals": eval_wait_time}, step=training_step)
        logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s")
    except ImportError:
        logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s (wandb not available)")
```
### 4. Initialization and Data Flow

Modified the `main` function to initialize the eval timing info and manage the data flow:

```python
# Initialize eval timing tracking
eval_timing_info = {"last_eval_finish_time": None}

# In the training loop:
eval_finish_time = maybe_evaluate(...)
if eval_finish_time is not None:
    eval_timing_info["last_eval_finish_time"] = eval_finish_time
```
## Metric Details

- **Metric Name**: `eval/wait_time_between_evals`
- **Type**: Scalar (float, seconds)
- **Logged To**: Wandb (when `with_tracking=True`)
- **Step**: Training step when eval is queued
- **Description**: Time in seconds between when one eval job finishes and the next one is queued
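If you also want an aggregate of this scalar in the run summary (not part of this change), Wandb's `define_metric` API is one option; a hedged sketch with a placeholder project name:

```python
import wandb

# Optional, not part of this PR: keep a mean summary of the wait-time scalar
# in addition to the per-step series. "my-project" is a placeholder.
wandb.init(project="my-project")
wandb.define_metric("eval/wait_time_between_evals", summary="mean")
```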
## Usage

The metric will be automatically logged to Wandb when (see the sketch after this list):

1. `args.with_tracking` is `True`
2. Wandb is properly initialized (`wandb.run` exists)
3. There is a previous eval finish time to calculate from (not the first eval)
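A compact way to read those conditions, as a sketch rather than the literal check in `grpo_fast.py` (which splits the logic between the main thread and the generation thread); the helper name is hypothetical:

```python
import wandb

def should_log_wait_time(with_tracking: bool, eval_timing_info: dict) -> bool:
    """Hypothetical helper: all three conditions above must hold."""
    return (
        with_tracking
        and wandb.run is not None  # only set once tracking has been initialized
        and eval_timing_info.get("last_eval_finish_time") is not None  # not the first eval
    )
```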
## Benefits

1. **Identify Bottlenecks**: Helps identify if eval jobs are waiting too long between runs
2. **Optimize Pipeline**: Provides data to optimize the evaluation pipeline timing
3. **Monitor Performance**: Tracks evaluation efficiency over time
4. **Debug Issues**: Helps debug evaluation-related performance issues
5. **Clean Architecture**: Uses function parameters instead of global variables for better maintainability
## Testing

The implementation includes:

- Logic tests to verify the timing calculations work correctly
- Proper error handling for edge cases (first eval job)
- Appropriate logging and metric naming conventions
- Integration with existing Wandb tracking infrastructure
- Clean data flow without global variables
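For the timing arithmetic itself, a logic test can stay independent of the training loop. A sketch; the helper below is hypothetical and only mirrors the calculation, it is not a function in `grpo_fast.py`:

```python
import time

def compute_eval_wait_time(eval_timing_info: dict, now: float):
    """Hypothetical helper mirroring the wait-time calculation."""
    last = eval_timing_info.get("last_eval_finish_time")
    return None if last is None else now - last

def test_first_eval_has_no_wait_time():
    # Edge case: nothing to measure before the first eval finishes.
    assert compute_eval_wait_time({"last_eval_finish_time": None}, time.time()) is None

def test_wait_time_is_elapsed_seconds():
    assert compute_eval_wait_time({"last_eval_finish_time": 90.0}, now=100.0) == 10.0
```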
## Files Modified

- `open_instruct/grpo_fast.py`: Main implementation (no global variables)
- `tests/test_eval_wait_time.py`: Unit tests (created)
## Example Output

When the metric is logged, you'll see output like:

```
[vLLM Thread] 📊 Eval wait time: 1.23s
```

And in Wandb, you'll see the metric `eval/wait_time_between_evals` plotted over training steps.
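To pull the series back out for offline analysis, the public Wandb API can be used; a sketch, with the entity, project, and run ID as placeholders:

```python
import wandb

# Placeholders: replace with your own entity, project, and run ID.
api = wandb.Api()
run = api.run("my-entity/my-project/run_id")
history = run.history(keys=["eval/wait_time_between_evals"])  # pandas DataFrame
print(history["eval/wait_time_between_evals"].describe())
```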
`open_instruct/grpo_fast.py`:

```diff
@@ -134,6 +134,8 @@
INVALID_LOGPROB = 1.0


@dataclass
class Args:
    # Dataset
@@ -1105,6 +1107,7 @@ def vllm_generate_thread(
    eval_freq: int,
    resume_training_step: int = 1,
    tool_use: bool = False,
+   eval_timing_info: Optional[Dict[str, float]] = None,
):
    def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
        # Split queries between engines
@@ -1165,6 +1168,21 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):

        # Evaluate the model
        if eval_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
+           # Record the time when eval job is queued and measure wait time since last eval finished
+           current_time = time.time()
+           if eval_timing_info is not None and eval_timing_info.get("last_eval_finish_time") is not None:
+               eval_wait_time = current_time - eval_timing_info["last_eval_finish_time"]
+               # Log the eval wait time to wandb if tracking is enabled
+               # This metric measures how much time the eval job was waiting between runs
+               # Note: We can't access args here, so we'll log to wandb directly if available
+               try:
+                   import wandb
+                   if hasattr(wandb, 'run') and wandb.run is not None:
+                       wandb.log({"eval/wait_time_between_evals": eval_wait_time}, step=training_step)
```
**Collaborator**

I don't quite understand why we want to track time between evals. We can get an arbitrary number of training steps in between evals, so isn't this usually just telling us how long n training steps is taking?

**Collaborator (Author)**

Yeah, that's fair. I'm trying to understand how we're doing evals generally. I think this is the wrong thing to measure.

**Collaborator**

Yeah, fair, it's a bit unclear. There are sort of two eval loops that run (copying from a Slack message I wrote recently):
```diff
+                   logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s")
+               except ImportError:
+                   logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s (wandb not available)")

            response_ids, finish_reasons, masks, info = generate_with_engines(
                eval_prompt_token_ids, eval_generation_config
            )
@@ -1773,7 +1791,7 @@ def maybe_evaluate(
    reward_fn,
    episode,
    writer,
-):
+) -> Optional[float]:
    """Optionally evaluate the model."""
    try:
        # timeout 0.01 if this is the last training step or we're not evaluating
@@ -1825,8 +1843,17 @@
        else:
            print_rich_table(df.iloc[:1])
        del table
+
+       # Record the time when eval job finishes
+       # This timestamp is used to calculate the wait time for the next eval job
+       eval_finish_time = time.time()
+       logger.info("[Main Thread] 📊 Evaluation job finished")
+
+       return eval_finish_time
+
    except Empty:
        logger.warning("[Main Thread] 🙈 Evaluation responses not received")
+       return None


def save_final_model(args: Args, policy_group: ModelGroup, training_step: int, wandb_url: str):
@@ -1951,6 +1978,9 @@ def cleanup_judge_clients():


def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
+   # Initialize eval timing tracking
+   eval_timing_info = {"last_eval_finish_time": None}
+
    tokenizer = make_tokenizer(tc, model_config)
    args = setup_runtime_variables(args)
    beaker_config, writer, wandb_url = setup_experiment_tracking(args, tc, model_config)
@@ -2013,6 +2043,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
            args.eval_freq,
            resume_training_step,
            args.tool_use,
+           eval_timing_info,
        ),
    )
    generate_thread.start()
@@ -2085,7 +2116,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
            wandb_url,
        )

-       maybe_evaluate(
+       eval_finish_time = maybe_evaluate(
            args,
            training_step,
            evaluation_inference_results_Q,
@@ -2097,6 +2128,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
            episode,
            writer,
        )
+
+       # Update eval timing info if eval was performed
+       if eval_finish_time is not None:
+           eval_timing_info["last_eval_finish_time"] = eval_finish_time

    save_final_model(args, policy_group, training_step, wandb_url)
```
`tests/test_eval_wait_time.py` (new file):

@@ -0,0 +1,137 @@

```python
import time
from unittest.mock import Mock, patch
from queue import Queue
from typing import List

# Import the functions we want to test
from open_instruct.grpo_fast import maybe_evaluate, vllm_generate_thread


class TestEvalWaitTime:
    """Test the eval wait time tracking functionality."""

    def test_eval_wait_time_tracking(self):
        """Test that eval wait time is properly tracked and logged."""

        # Mock wandb
        with patch('open_instruct.grpo_fast.wandb') as mock_wandb:
            mock_wandb.run = Mock()
            mock_wandb.log = Mock()

            # Mock args
            args = Mock()
            args.with_tracking = True
            args.num_training_steps = 10
            args.eval_freq = 2

            # Mock other dependencies
            tokenizer = Mock()
            tokenizer.batch_decode.return_value = ["test response"]
            tokenizer.pad_token = "<pad>"

            eval_prompt_token_ids: List[int] = [1, 2, 3]
            eval_ground_truths = ["test ground truth"]
            eval_dataset_names = ["test_dataset"]

            # Mock reward function
            async def mock_reward_fn(*args, **kwargs):
                return [1.0], {"test_metric": 1.0}

            reward_fn = Mock()
            reward_fn.return_value = mock_reward_fn()

            # Mock writer
            writer = Mock()

            # Create queues
            evaluation_inference_results_Q = Queue()

            # Test that eval wait time is tracked
            # First eval - should not log wait time since it's the first one
            maybe_evaluate(
                args=args,
                training_step=2,
                evaluation_inference_results_Q=evaluation_inference_results_Q,
                tokenizer=tokenizer,
                eval_prompt_token_ids=eval_prompt_token_ids,
                eval_ground_truths=eval_ground_truths,
                eval_dataset_names=eval_dataset_names,
                reward_fn=reward_fn,
                episode=1,
                writer=writer,
            )

            # Put some mock data in the queue for the eval
            evaluation_inference_results_Q.put((
                [[1, 2, 3, 4]],  # responses
                ["stop"],  # finish_reasons
                [[1, 1, 1, 1]],  # masks
                ([0], [0], [""], [""], [0], [False])  # infos
            ))

            # Call maybe_evaluate again to process the data
            eval_finish_time = maybe_evaluate(
                args=args,
                training_step=2,
                evaluation_inference_results_Q=evaluation_inference_results_Q,
                tokenizer=tokenizer,
                eval_prompt_token_ids=eval_prompt_token_ids,
                eval_ground_truths=eval_ground_truths,
                eval_dataset_names=eval_dataset_names,
                reward_fn=reward_fn,
                episode=1,
                writer=writer,
            )

            # Verify that the eval finish time was returned
            assert eval_finish_time is not None, "Eval finish time should be returned"

            # Now test the vllm_generate_thread function
            # Mock the generate_with_engines function
            def mock_generate_with_engines(prompts, sampling_params):
                return (
                    [[1, 2, 3, 4]],  # response_ids
                    ["stop"],  # finish_reasons
                    [[1, 1, 1, 1]],  # masks
                    ([0], [0], [""], [""], [0], [False])  # infos
                )

            # Mock vLLM engines
            mock_engines = [Mock()]

            # Mock generation config
            generation_config = Mock()
            eval_generation_config = Mock()

            # Create queues
            inference_results_Q = Queue()
            param_prompt_Q = Queue()
            evaluation_inference_results_Q = Queue()

            # Put some data in the param_prompt_Q
            param_prompt_Q.put((None, [[1, 2, 3]]))

            # Test that the eval wait time is logged when the next eval is queued
            with patch('open_instruct.grpo_fast.generate_with_engines', mock_generate_with_engines):
                # This should trigger an eval at step 2
                eval_timing_info = {"last_eval_finish_time": time.time() - 1.0}  # Simulate previous eval
                vllm_generate_thread(
                    vllm_engines=mock_engines,
                    generation_config=generation_config,
                    eval_generation_config=eval_generation_config,
                    inference_results_Q=inference_results_Q,
                    param_prompt_Q=param_prompt_Q,
                    num_training_steps=5,
                    eval_prompt_token_ids=eval_prompt_token_ids,
                    evaluation_inference_results_Q=evaluation_inference_results_Q,
                    eval_freq=2,
                    resume_training_step=2,
                    tool_use=False,
                    eval_timing_info=eval_timing_info,
                )

            # Verify that wandb.log was called with the eval wait time
            mock_wandb.log.assert_called()
            call_args = mock_wandb.log.call_args_list
            eval_wait_time_calls = [call for call in call_args if "eval/wait_time_between_evals" in str(call)]
            assert len(eval_wait_time_calls) > 0, "Eval wait time should be logged to wandb"
```
**Collaborator**

Could we put this file into a documentation folder or similar? I also feel like we don't necessarily need the full file, maybe just an explanation of the flag.

**Collaborator (Author)**

Ah, yeah, sorry. I'm trying out Cursor's background agents, and it put this here. Let me mark the PR (and the other Cursor ones) as draft until I clean it up.

I totally agree with you.