109 changes: 109 additions & 0 deletions EVAL_WAIT_TIME_IMPLEMENTATION.md
@@ -0,0 +1,109 @@
# Eval Wait Time Metric Implementation
> **Collaborator:** Could we put this file into a documentation folder or similar? I also feel like we don't necessarily need the full file; maybe just an explanation of the flag.
>
> **Collaborator (Author):** Ah, yeah, sorry. I'm trying out Cursor's background agents, and it put this here. Let me mark the PR (and the other Cursor ones) as draft until I clean it up.
>
> I totally agree with you.


## Overview

This implementation adds a new metric to the `grpo_fast.py` script that tracks the time between eval jobs. The metric measures how long the eval job waits between runs, which helps identify bottlenecks in the evaluation pipeline.
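
Concretely, if one eval job finishes at time `t_finish` and the next eval job is queued at time `t_queue`, the metric is just `t_queue - t_finish`. A toy illustration (the numbers here are made up):

```python
t_finish = 1000.00  # seconds: when the previous eval job finished
t_queue = 1001.23   # seconds: when the next eval job is queued
wait_time = t_queue - t_finish  # 1.23, logged as eval/wait_time_between_evals
```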

## Implementation Details

### 1. Eval Timing Info Dictionary

Added a dictionary to track the last eval finish time without using global variables:

```python
# Initialize eval timing tracking
eval_timing_info = {"last_eval_finish_time": None}
```

### 2. Recording Eval Finish Time

Modified the `maybe_evaluate` function to return the eval finish time:

```python
# Record the time when eval job finishes
# This timestamp is used to calculate the wait time for the next eval job
eval_finish_time = time.time()
logger.info("[Main Thread] 📊 Evaluation job finished")

return eval_finish_time
```

### 3. Measuring Wait Time

Modified the `vllm_generate_thread` function to measure and log the wait time when the next eval job is queued:

```python
# Record the time when eval job is queued and measure wait time since last eval finished
current_time = time.time()
if eval_timing_info is not None and eval_timing_info.get("last_eval_finish_time") is not None:
eval_wait_time = current_time - eval_timing_info["last_eval_finish_time"]
# Log the eval wait time to wandb if tracking is enabled
# This metric measures how much time the eval job was waiting between runs
try:
import wandb
if hasattr(wandb, 'run') and wandb.run is not None:
wandb.log({"eval/wait_time_between_evals": eval_wait_time}, step=training_step)
logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s")
except ImportError:
logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s (wandb not available)")
```

### 4. Initialization and Data Flow

Modified the `main` function to initialize the eval timing info and manage the data flow:

```python
# Initialize eval timing tracking
eval_timing_info = {"last_eval_finish_time": None}

# In the training loop:
eval_finish_time = maybe_evaluate(...)
if eval_finish_time is not None:
eval_timing_info["last_eval_finish_time"] = eval_finish_time
```
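
To make the data flow concrete, here is a minimal, runnable sketch of the same pattern using stand-in functions; `fake_eval` and `queue_next_eval` are illustrative names, not the real `grpo_fast.py` APIs:

```python
import time
from typing import Optional


def fake_eval() -> Optional[float]:
    """Stand-in for maybe_evaluate: run an eval and return its finish time."""
    time.sleep(0.1)  # pretend the eval takes some time
    return time.time()


def queue_next_eval(eval_timing_info: dict, training_step: int) -> Optional[float]:
    """Stand-in for the generation-thread logic: compute the wait time
    since the previous eval finished, if there was one."""
    now = time.time()
    last_finish = eval_timing_info.get("last_eval_finish_time")
    if last_finish is None:
        return None  # first eval: nothing to compare against
    wait_time = now - last_finish
    print(f"step {training_step}: eval/wait_time_between_evals = {wait_time:.2f}s")
    return wait_time


eval_timing_info = {"last_eval_finish_time": None}
for training_step in range(1, 5):
    time.sleep(0.05)  # pretend training work happens here
    queue_next_eval(eval_timing_info, training_step)
    eval_timing_info["last_eval_finish_time"] = fake_eval()
```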

## Metric Details

- **Metric Name**: `eval/wait_time_between_evals`
- **Type**: Scalar (float, seconds)
- **Logged To**: Wandb (when `with_tracking=True`)
- **Step**: Training step when eval is queued
- **Description**: Time in seconds between when one eval job finishes and the next one is queued

## Usage

The metric will be automatically logged to Wandb when:
1. `args.with_tracking` is `True`
2. Wandb is properly initialized (`wandb.run` exists)
3. There is a previous eval finish time to calculate from (not the first eval)
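
Putting these conditions together, the guard amounts to roughly the following. This is an illustrative sketch, not a verbatim copy of the `grpo_fast.py` code; in the actual implementation the generation thread cannot access `args`, so it checks `wandb.run` directly:

```python
import time

import wandb  # assumed installed; the real code guards this import with try/except


def maybe_log_wait_time(eval_timing_info: dict, training_step: int) -> None:
    """Illustrative sketch of the logging guard described above."""
    last_finish = eval_timing_info.get("last_eval_finish_time")
    if last_finish is None:
        return  # first eval: nothing to measure yet
    if wandb.run is None:
        return  # tracking not initialized, so skip the wandb metric
    wait_time = time.time() - last_finish
    wandb.log({"eval/wait_time_between_evals": wait_time}, step=training_step)
```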

## Benefits

1. **Identify Bottlenecks**: Helps identify if eval jobs are waiting too long between runs
2. **Optimize Pipeline**: Provides data to optimize the evaluation pipeline timing
3. **Monitor Performance**: Tracks evaluation efficiency over time
4. **Debug Issues**: Helps debug evaluation-related performance issues
5. **Clean Architecture**: Uses function parameters instead of global variables for better maintainability

## Testing

The implementation includes:
- Logic tests to verify the timing calculations work correctly
- Proper error handling for edge cases (first eval job)
- Appropriate logging and metric naming conventions
- Integration with existing Wandb tracking infrastructure
- Clean data flow without global variables
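
For reference, here is a minimal sketch of the kind of timing-logic test listed above. It exercises only the wait-time arithmetic via a hypothetical helper, not the real `grpo_fast.py` functions (the actual unit tests live in `tests/test_eval_wait_time.py`):

```python
import time


def compute_wait_time(eval_timing_info: dict, now: float):
    """Hypothetical helper mirroring the wait-time calculation."""
    last = eval_timing_info.get("last_eval_finish_time")
    return None if last is None else now - last


def test_first_eval_has_no_wait_time():
    # Edge case: no previous eval, so there is nothing to measure
    assert compute_wait_time({"last_eval_finish_time": None}, time.time()) is None


def test_wait_time_is_elapsed_seconds():
    info = {"last_eval_finish_time": 100.0}
    assert compute_wait_time(info, 101.5) == 1.5
```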

## Files Modified

- `open_instruct/grpo_fast.py`: Main implementation (no global variables)
- `tests/test_eval_wait_time.py`: Unit tests (created)

## Example Output

When the metric is logged, you'll see output like:
```
[vLLM Thread] 📊 Eval wait time: 1.23s
```

And in Wandb, you'll see the metric `eval/wait_time_between_evals` plotted over training steps.
39 changes: 37 additions & 2 deletions open_instruct/grpo_fast.py
@@ -134,6 +134,8 @@
INVALID_LOGPROB = 1.0




@dataclass
class Args:
# Dataset
@@ -1105,6 +1107,7 @@ def vllm_generate_thread(
eval_freq: int,
resume_training_step: int = 1,
tool_use: bool = False,
eval_timing_info: Optional[Dict[str, float]] = None,
):
def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
# Split queries between engines
@@ -1165,6 +1168,21 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar

# Evaluate the model
if eval_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
# Record the time when eval job is queued and measure wait time since last eval finished
current_time = time.time()
if eval_timing_info is not None and eval_timing_info.get("last_eval_finish_time") is not None:
eval_wait_time = current_time - eval_timing_info["last_eval_finish_time"]
# Log the eval wait time to wandb if tracking is enabled
# This metric measures how much time the eval job was waiting between runs
# Note: We can't access args here, so we'll log to wandb directly if available
try:
import wandb
if hasattr(wandb, 'run') and wandb.run is not None:
wandb.log({"eval/wait_time_between_evals": eval_wait_time}, step=training_step)
> **Collaborator:** I don't quite understand why we want to track the time between evals. We can get an arbitrary number of training steps between evals, so isn't this usually just telling us how long n training steps is taking?
>
> **@finbarrtimbers (Author), Jul 14, 2025:** Yeah, that's fair. I'm trying to understand how we're doing evals generally. I think this is the wrong thing to measure.
>
> **Collaborator:** Yeah, fair, it's a bit unclear. There are sort of two eval loops that run (copying from a Slack message I wrote recently):
>
> 1. In-loop evals, which just re-use the generation and reward/verifier code. I'm not crazy about these, but they can be useful for observing how model generations change over training, and they can give you an idea of val reward. Controlled by `num_evals`, which works out the total number of training steps via the episode count and then sets `eval_freq` accordingly: https://github.com/allenai/open-instruct/blob/main/open_instruct/grpo_fast.py#L1459. I'm actually not really a fan of this way of doing it, but I don't feel strongly enough to change it, haha.
>
> 2. oe-eval evals, which are downstream, launched as separate jobs, and the actual final numbers we usually care about. These are tied to `save_freq` AND require setting `try_launch_beaker_eval_jobs_on_weka` to `True` (even if on Augusta; we should rename this arg). We could add a further check at e.g. https://github.com/allenai/open-instruct/blob/main/open_instruct/grpo_fast.py#L1785. Note that we can't really untie this, since the oe-eval jobs need some sort of path.

logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s")
except ImportError:
logger.info(f"[vLLM Thread] 📊 Eval wait time: {eval_wait_time:.2f}s (wandb not available)")

response_ids, finish_reasons, masks, info = generate_with_engines(
eval_prompt_token_ids, eval_generation_config
)
@@ -1773,7 +1791,7 @@ def maybe_evaluate(
reward_fn,
episode,
writer,
):
) -> Optional[float]:
"""Optionally evaluate the model."""
try:
# timeout 0.01 if this is the last training step or we're not evaluating
@@ -1825,8 +1843,17 @@
else:
print_rich_table(df.iloc[:1])
del table

# Record the time when eval job finishes
# This timestamp is used to calculate the wait time for the next eval job
eval_finish_time = time.time()
logger.info("[Main Thread] 📊 Evaluation job finished")

return eval_finish_time

except Empty:
logger.warning("[Main Thread] 🙈 Evaluation responses not received")
return None


def save_final_model(args: Args, policy_group: ModelGroup, training_step: int, wandb_url: str):
@@ -1951,6 +1978,9 @@ def cleanup_judge_clients():


def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
# Initialize eval timing tracking
eval_timing_info = {"last_eval_finish_time": None}

tokenizer = make_tokenizer(tc, model_config)
args = setup_runtime_variables(args)
beaker_config, writer, wandb_url = setup_experiment_tracking(args, tc, model_config)
@@ -2013,6 +2043,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
args.eval_freq,
resume_training_step,
args.tool_use,
eval_timing_info,
),
)
generate_thread.start()
@@ -2085,7 +2116,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
wandb_url,
)

maybe_evaluate(
eval_finish_time = maybe_evaluate(
args,
training_step,
evaluation_inference_results_Q,
@@ -2097,6 +2128,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
episode,
writer,
)

# Update eval timing info if eval was performed
if eval_finish_time is not None:
eval_timing_info["last_eval_finish_time"] = eval_finish_time

save_final_model(args, policy_group, training_step, wandb_url)

137 changes: 137 additions & 0 deletions tests/test_eval_wait_time.py
@@ -0,0 +1,137 @@
import time
from unittest.mock import Mock, patch
from queue import Queue
from typing import List

# Import the functions we want to test
from open_instruct.grpo_fast import maybe_evaluate, vllm_generate_thread


class TestEvalWaitTime:
"""Test the eval wait time tracking functionality."""

def test_eval_wait_time_tracking(self):
"""Test that eval wait time is properly tracked and logged."""

# Mock wandb
with patch('open_instruct.grpo_fast.wandb') as mock_wandb:
mock_wandb.run = Mock()
mock_wandb.log = Mock()

# Mock args
args = Mock()
args.with_tracking = True
args.num_training_steps = 10
args.eval_freq = 2

# Mock other dependencies
tokenizer = Mock()
tokenizer.batch_decode.return_value = ["test response"]
tokenizer.pad_token = "<pad>"

eval_prompt_token_ids: List[int] = [1, 2, 3]
eval_ground_truths = ["test ground truth"]
eval_dataset_names = ["test_dataset"]

# Mock reward function
async def mock_reward_fn(*args, **kwargs):
return [1.0], {"test_metric": 1.0}

reward_fn = Mock()
reward_fn.return_value = mock_reward_fn()

# Mock writer
writer = Mock()

# Create queues
evaluation_inference_results_Q = Queue()

# Test that eval wait time is tracked
# First eval - should not log wait time since it's the first one
maybe_evaluate(
args=args,
training_step=2,
evaluation_inference_results_Q=evaluation_inference_results_Q,
tokenizer=tokenizer,
eval_prompt_token_ids=eval_prompt_token_ids,
eval_ground_truths=eval_ground_truths,
eval_dataset_names=eval_dataset_names,
reward_fn=reward_fn,
episode=1,
writer=writer,
)

# Put some mock data in the queue for the eval
evaluation_inference_results_Q.put((
[[1, 2, 3, 4]], # responses
["stop"], # finish_reasons
[[1, 1, 1, 1]], # masks
([0], [0], [""], [""], [0], [False]) # infos
))

# Call maybe_evaluate again to process the data
eval_finish_time = maybe_evaluate(
args=args,
training_step=2,
evaluation_inference_results_Q=evaluation_inference_results_Q,
tokenizer=tokenizer,
eval_prompt_token_ids=eval_prompt_token_ids,
eval_ground_truths=eval_ground_truths,
eval_dataset_names=eval_dataset_names,
reward_fn=reward_fn,
episode=1,
writer=writer,
)

# Verify that the eval finish time was returned
assert eval_finish_time is not None, "Eval finish time should be returned"

# Now test the vllm_generate_thread function
# Mock the generate_with_engines function
def mock_generate_with_engines(prompts, sampling_params):
return (
[[1, 2, 3, 4]], # response_ids
["stop"], # finish_reasons
[[1, 1, 1, 1]], # masks
([0], [0], [""], [""], [0], [False]) # infos
)

# Mock vLLM engines
mock_engines = [Mock()]

# Mock generation config
generation_config = Mock()
eval_generation_config = Mock()

# Create queues
inference_results_Q = Queue()
param_prompt_Q = Queue()
evaluation_inference_results_Q = Queue()

# Put some data in the param_prompt_Q
param_prompt_Q.put((None, [[1, 2, 3]]))

# Test that the eval wait time is logged when the next eval is queued
with patch('open_instruct.grpo_fast.generate_with_engines', mock_generate_with_engines):
# This should trigger an eval at step 2
eval_timing_info = {"last_eval_finish_time": time.time() - 1.0} # Simulate previous eval
vllm_generate_thread(
vllm_engines=mock_engines,
generation_config=generation_config,
eval_generation_config=eval_generation_config,
inference_results_Q=inference_results_Q,
param_prompt_Q=param_prompt_Q,
num_training_steps=5,
eval_prompt_token_ids=eval_prompt_token_ids,
evaluation_inference_results_Q=evaluation_inference_results_Q,
eval_freq=2,
resume_training_step=2,
tool_use=False,
eval_timing_info=eval_timing_info,
)

# Verify that wandb.log was called with the eval wait time
mock_wandb.log.assert_called()
call_args = mock_wandb.log.call_args_list
eval_wait_time_calls = [call for call in call_args if "eval/wait_time_between_evals" in str(call)]
assert len(eval_wait_time_calls) > 0, "Eval wait time should be logged to wandb"