diff --git a/examples/Intellect-3.1/rl.toml b/examples/Intellect-3.1/rl.toml index f4ea5ec525..b16e45349c 100644 --- a/examples/Intellect-3.1/rl.toml +++ b/examples/Intellect-3.1/rl.toml @@ -50,30 +50,34 @@ oversampling_factor = 2 [[orchestrator.env]] id = "mini-swe-agent-plus" name = "swe" +ratio = 0.3 args = { max_turns = 200, cpu_cores = 2, memory_gb = 4, disk_size_gb = 4, labels = ["mini-swe-agent-plus"], total_timeout_minutes = 720, sandbox_client_max_workers = 256, max_command_timeouts = 3, sandbox_command_timeout = 30} [[orchestrator.env]] id = "deepdive" name = "deepdive" +ratio = 0.2 args = { finish_with_tool = true, open_max_workers = 128, cache_dir = "/tmp/i3_deepdive_cache_train" } [[orchestrator.env]] id = "math-env" name = "math" +ratio = 0.3 args = { min_avg_reward = 0.0, max_avg_reward = 0.874} [[orchestrator.env]] id = "logic-env" -args = { min_avg_reward = 0.0, max_avg_reward = 0.874 } name = "logic" +ratio = 0.2 +args = { min_avg_reward = 0.0, max_avg_reward = 0.874 } [[orchestrator.env]] id = "code-env" name = "code" +ratio = 0.2 args = { pool_size = 512 } [orchestrator.buffer] -env_ratios = [0.3, 0.2, 0.3, 0.2, 0.2] easy_threshold = 1.0 online_difficulty_filtering = true seed = 42 diff --git a/src/prime_rl/configs/orchestrator.py b/src/prime_rl/configs/orchestrator.py index f9c126ebac..f352f5a8b7 100644 --- a/src/prime_rl/configs/orchestrator.py +++ b/src/prime_rl/configs/orchestrator.py @@ -268,47 +268,69 @@ class EvalSaveHFConfig(BaseConfig): class EnvConfig(BaseConfig): - """Configures an environment for training.""" + """Base environment configuration.""" id: Annotated[str, Field(description="ID of the environment to use.")] = "reverse-text" args: Annotated[dict, Field(description="Arguments to pass to the environment.")] = {} name: Annotated[str | None, Field(description="Name of the environment to use.")] = None - address: Annotated[ - str | None, - Field( - description="Address of the environment server. If None, will spawn an environment server in a subprocess automatically.If given, will try to connect an environment client to the environment server at this address." - ), - ] = None extra_env_kwargs: Annotated[ dict[str, Any], Field( description=( - "Extra kwargs passed to an env (e.g. seq_len, score_rollouts). This field is auto-populated with the seq_len, and score_rollouts for training envs on the orchestrator. It is generally NOT recommended for this field to be overriden by the user. It's main use case is to match the extra_env_kwargs when running an env in an isolated environment server." + "Extra kwargs passed to an env (e.g. seq_len, score_rollouts). This field is auto-populated " + "by the orchestrator. It is generally NOT recommended for this field to be overridden by the user." ), ), ] = {} + + address: Annotated[ + str | None, + Field( + description="Address of the environment server. If None, will spawn an environment server in a subprocess automatically. If given, will connect an environment client to the server at this address." + ), + ] = None + num_workers: Annotated[ int | Literal["auto"], Field( - description=( - "Number of env server worker processes. " - "Set to 'auto' to scale based on the env's concurrency (1 worker per 256 concurrent rollouts). " - "When setting manually, we recommend sizing so that each worker handles at most 256 concurrent rollouts. " - "Only used when the orchestrator spawns the env server (i.e. address is None)." - ), + description="Number of env server worker processes. " + "Set to 'auto' to scale based on the env's concurrency (1 worker per 256 concurrent rollouts). " + "When setting manually, we recommend sizing so that each worker handles at most 256 concurrent rollouts. " + "Only used when the orchestrator spawns the env server (i.e. address is None)." ), ] = "auto" + + score_rollouts: Annotated[ + bool, + Field( + description="Whether to score rollouts using the environment rubric. If False, rewards are always set to 0.", + ), + ] = True + + ratio: Annotated[ + float | None, + Field( + gt=0, + description="Sampling ratio for this environment in the buffer. If None for all envs, samples uniformly across all available problems. If set, ratios across all envs must sum to 1.", + ), + ] = None + max_retries: Annotated[ int, Field( ge=0, - description="Maximum number of times the environment will retry a failed rollout.", + description="Maximum number of internal retries the environment will attempt before declaring a rollout failed.", ), ] = 0 + @property + def stripped_id(self) -> str: + """Environment ID without the @version suffix.""" + return self.id.split("@")[0] + @property def resolved_name(self) -> str: - return self.name or self.id.split("@")[0] + return self.name or self.stripped_id @model_validator(mode="after") def validate_env_name(self): @@ -319,54 +341,71 @@ def validate_env_name(self): return self +TrainEnvConfig = EnvConfig + + class EvalEnvConfig(EnvConfig): - """Configures an environment for evaluation.""" + """Configures an evaluation environment.""" num_examples: Annotated[ - int | None, - Field( - description="Number of examples to evaluate per environment. If not set, will use 'num_examples' from main config." - ), - ] = None - rollouts_per_example: Annotated[ - int | None, - Field( - description="Number of samples to generate per example for each environment. If not set, will use 'rollouts_per_example' from main config." - ), - ] = None + int, + Field(description="Number of examples to evaluate."), + ] = -1 - skip_first: Annotated[ + rollouts_per_example: Annotated[ int, - Field( - description="Number of examples to skip from the beginning of the dataset.", - ), - ] = 0 + Field(ge=1, description="Number of rollouts to generate per example."), + ] = 1 -class ValConfig(BaseConfig): - """Configures the validation of the model.""" +class TrainEnvsConfig(BaseConfig): + """Configures all training environments.""" - num_examples: Annotated[ - int, Field(ge=1, description="Number of examples to use for validation. If -1, will use all examples.") - ] = 16 - rollouts_per_example: Annotated[ - int, Field(ge=1, description="Number of samples to generate per example for validation.") - ] = 1 - interval: Annotated[int, Field(description="Interval at which to validate the model.")] = 10 + env: list[TrainEnvConfig] = [TrainEnvConfig()] + @model_validator(mode="after") + def validate_unique_env_names(self): + env_names = [env.resolved_name for env in self.env] + duplicates = [n for n in env_names if env_names.count(n) > 1] + if duplicates: + raise ValueError( + f"Duplicate training environment names: {set(duplicates)}. Each env must have a unique name." + ) + return self -class EvalConfig(BaseConfig): - """Configures evaluation using verifiers environments.""" + @model_validator(mode="after") + def validate_env_ratios(self): + ratios = [env.ratio for env in self.env] + if all(r is None for r in ratios): + return self + if any(r is None for r in ratios): + raise ValueError("Either all envs must have a ratio or none of them. Got a mix of set and unset ratios.") + return self + + +class EvalEnvsConfig(BaseConfig): + """Configures all evaluation environments.""" env: list[EvalEnvConfig] = [EvalEnvConfig()] + + @model_validator(mode="after") + def validate_unique_env_names(self): + env_names = [env.resolved_name for env in self.env] + duplicates = [n for n in env_names if env_names.count(n) > 1] + if duplicates: + raise ValueError( + f"Duplicate evaluation environment names: {set(duplicates)}. Each env must have a unique name." + ) + return self + + +class EvalConfig(EvalEnvsConfig): + """Configures evaluation using verifiers environments.""" + sampling: EvalSamplingConfig = Field( default_factory=EvalSamplingConfig, description="Shared sampling configuration for evals; can differ from training sampling.", ) - num_examples: Annotated[int, Field(description="Number of examples to evaluate per environment.")] = -1 - rollouts_per_example: Annotated[ - int, Field(ge=1, description="Number of samples to generate per example for each environment.") - ] = 1 interval: Annotated[ int, @@ -401,14 +440,6 @@ class EvalConfig(BaseConfig): ), ] = False - @model_validator(mode="after") - def validate_unique_env_names(self): - env_names = [env.resolved_name for env in self.env] - duplicates = [n for n in env_names if env_names.count(n) > 1] - if duplicates: - raise ValueError(f"Duplicate eval environment names: {set(duplicates)}. Each env must have a unique name.") - return self - class CheckpointConfig(BaseConfig): """Configures checkpointing the orchestrator.""" @@ -472,16 +503,6 @@ class BufferConfig(BaseConfig): ), ] = None - env_ratios: Annotated[ - list[float] | None, - Field( - description=( - "Ratios for sampling from each environment. " - "If None, samples uniformly across all available problems (not environments)." - ), - ), - ] = None - easy_threshold: Annotated[ float | None, Field( @@ -535,12 +556,6 @@ def validate_thresholds(self): assert self.easy_threshold > self.hard_threshold, "easy_threshold must be greater than hard_threshold." return self - @model_validator(mode="after") - def validate_env_ratios(self): - if self.env_ratios is not None: - assert all(ratio > 0 for ratio in self.env_ratios), "All env_ratios must be positive." - return self - class VerificationConfig(BaseConfig): """Configures rollout verification and rubric scoring.""" @@ -703,7 +718,7 @@ class TeacherRolloutModelConfig(BaseConfig): ] = ModelConfig() -class OrchestratorConfig(BaseConfig): +class OrchestratorConfig(TrainEnvsConfig): """Configures the orchestrator for RL training.""" # The OAI client configuration @@ -739,9 +754,6 @@ class OrchestratorConfig(BaseConfig): # The sampling configuration sampling: SamplingConfig = SamplingConfig() - # The environment configuration - env: list[EnvConfig] = [EnvConfig()] - # The evaluation configuration eval: EvalConfig | None = None @@ -769,9 +781,6 @@ class OrchestratorConfig(BaseConfig): # The checkpoint configuration ckpt: CheckpointConfig | None = None - # The validation configuration - val: ValConfig | None = None - weight_broadcast: WeightBroadcastConfig = FileSystemWeightBroadcastConfig() rollout_transport: TransportConfig = FileSystemTransportConfig() @@ -783,13 +792,6 @@ class OrchestratorConfig(BaseConfig): ), ] = Path("outputs/run_default") - max_concurrent: Annotated[ - int | None, - Field( - description="Maximum number of concurrent rollouts to generate and score per-environment. If None, will not limit concurrency.", - ), - ] = None - tasks_per_minute: Annotated[ int | None, Field( @@ -908,6 +910,16 @@ class OrchestratorConfig(BaseConfig): ), ] = True + @property + def train_envs(self) -> list[TrainEnvConfig]: + return self.env + + @property + def eval_envs(self) -> list[EvalEnvConfig]: + if self.eval is None: + return [] + return self.eval.env + @model_validator(mode="after") def validate_unique_filter_types(self): types = [f.type for f in self.filters] @@ -915,12 +927,6 @@ def validate_unique_filter_types(self): raise ValueError(f"Duplicate filter types: {types}. Each filter type may only appear once.") return self - @model_validator(mode="after") - def validate_max_concurrent(self): - if self.max_concurrent is not None and self.max_concurrent < self.rollouts_per_example: - raise ValueError("max_concurrent must be at least the number of rollouts per example") - return self - @model_validator(mode="after") def nccl_max_async_level(self): if self.weight_broadcast.type == "nccl": @@ -960,20 +966,6 @@ def resolve_batching(self): raise ValueError("max_inflight_rollouts must be at least the number of rollouts per example") return self - @model_validator(mode="after") - def validate_unique_env_names(self): - env_names = [env.resolved_name for env in self.env] - duplicates = [n for n in env_names if env_names.count(n) > 1] - if duplicates: - raise ValueError(f"Duplicate environment names: {set(duplicates)}. Each env must have a unique name.") - return self - - @model_validator(mode="after") - def validate_env_ratios(self): - if self.buffer.env_ratios is not None: - assert len(self.buffer.env_ratios) == len(self.env), "env_ratios length must match number of environments" - return self - @model_validator(mode="after") def validate_verification_config(self): if self.verification.enabled: @@ -1019,15 +1011,13 @@ def auto_setup_bench(self): return self @model_validator(mode="after") - def resolve_extra_env_kwargs(self): - train_extra_env_kwargs = dict( - max_seq_len=self.seq_len, - score_rollouts=self.verification.enabled, - ) + def resolve_env_config(self): + """Populate extra_env_kwargs from top-level and per-env fields.""" for env in self.env: - # extra_env_kwargs is not meant to be used by the user, we shamelessly override here - env.extra_env_kwargs.update(train_extra_env_kwargs) - + env.extra_env_kwargs.update( + max_seq_len=self.seq_len, + score_rollouts=env.score_rollouts, + ) return self @model_validator(mode="after") diff --git a/src/prime_rl/orchestrator/buffer.py b/src/prime_rl/orchestrator/buffer.py index 5b659b3467..54a27d6cb8 100644 --- a/src/prime_rl/orchestrator/buffer.py +++ b/src/prime_rl/orchestrator/buffer.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import hashlib import json import random from collections import defaultdict from functools import partial from pathlib import Path -from typing import cast +from typing import TYPE_CHECKING, cast import verifiers as vf from datasets import Dataset @@ -14,307 +16,308 @@ from prime_rl.utils.logger import get_logger from prime_rl.utils.utils import format_num, mean, mean_normalize +if TYPE_CHECKING: + from prime_rl.orchestrator.envs import Envs -class Buffer: - """A buffer for storing rollouts and metadata.""" - POOLS = ["easy", "normal", "hard"] +POOLS = ["easy", "normal", "hard"] - def __init__( - self, - dataset: Dataset, - env_names: list[str], - buffer_config: BufferConfig, - ): - self.dataset = dataset - self.env_names = env_names - self.config = buffer_config - self.logger = get_logger() - if self.config.seed is not None: - random.seed(self.config.seed) - - # Basic assertions - assert "example_id" in self.dataset.column_names, "The dataset must contain a `example_id` column." - assert "prompt" in self.dataset.column_names, "The dataset must contain a `prompt` column." - assert "task" in self.dataset.column_names, "The dataset must contain a `task` column." - assert len(self.dataset) > 0, "The dataset must contain at least one example." - assert isinstance(self.dataset["example_id"][0], int), "The `example_id` column must be of type int." - assert len(set(self.dataset["example_id"])) == len(self.dataset), "The `example_id` column must be unique." - assert set(self.dataset["task"]) == set(self.env_names), "The `task` column must contain all environment names." - - # Initialize example buffer (env_name -> (example_id -> example)) - self.example_buffer: dict[str, dict[int, dict]] = defaultdict(dict) - for example in map(partial(cast, dict), self.dataset): - self.example_buffer[example["task"]][example["example_id"]] = example - assert len(self.example_buffer) == len(self.env_names) - self.logger.debug( - f"Initialized buffer with {format_num(len(self.dataset), precision=0)} example(s) in {len(self.env_names)} environment(s)" - ) +class EnvBuffer: + """Per-environment buffer managing examples and difficulty pools.""" - if self.config.env_ratios is not None: - # Convert ratios to probabilities - env_ratio = mean_normalize(self.config.env_ratios) - self.env_probs = {env_name: ratio for env_name, ratio in zip(self.env_names, env_ratio)} - self.logger.debug( - f"Sampling buffer according to provided environment ratios ({', '.join(f'{k}={v:.2f}' for k, v in self.env_probs.items())})" - ) - else: - # Count examples per environment to sample according to natural env distribution - env_counts = [len(self.example_buffer[env_name]) for env_name in self.env_names] - env_ratio = mean_normalize(env_counts) - self.env_probs = {env_name: ratio for env_name, ratio in zip(self.env_names, env_ratio)} - self.logger.debug( - f"Sampling buffer according to natural environment distribution ({', '.join(f'{k}={v:.2f}' for k, v in self.env_probs.items())})" - ) + def __init__(self, env_name: str, dataset: Dataset, config: BufferConfig): + self.env_name = env_name + self.config = config + + assert len(dataset) > 0, f"Dataset for {env_name} must contain at least one example." + assert "example_id" in dataset.column_names, f"Dataset for {env_name} must contain an `example_id` column." + assert "prompt" in dataset.column_names, f"Dataset for {env_name} must contain a `prompt` column." + + self.examples: dict[int, dict] = {} + for example in map(partial(cast, dict), dataset): + example["env_name"] = env_name + example["task"] = env_name # for vf.RolloutInput compat + self.examples[example["example_id"]] = example - # Initialize buffers for easy/ hard examples self.easy_examples: list[dict] = [] self.hard_examples: list[dict] = [] - # Initialize rollout buffer (flat list of rollouts) - self.rollout_buffer: list[vf.RolloutOutput] = [] - self.reset_step_metrics() + @property + def num_normal(self) -> int: + return len(self.examples) + + @property + def num_total(self) -> int: + return self.num_normal + len(self.easy_examples) + len(self.hard_examples) + + def sample_example(self) -> dict: + return random.choice(list(self.examples.values())) + def get_example_hash(self, example: dict) -> str: - """Returns a hash of the example based on hash keys.""" hash_keys = [key for key in self.config.hash_keys if key in example] assert hash_keys, "No hashable keys found in example." return hashlib.sha256(json.dumps([example[key] for key in hash_keys]).encode()).hexdigest() - def save(self, path: Path) -> None: - """Saves pool assignments and rollout buffer.""" - path.mkdir(parents=True, exist_ok=True) + def update_pools(self, example_id: int, avg_reward: float, num_rollouts: int) -> list[str]: + """Assign example to pool based on reward. Returns pool name.""" + if self.config.easy_threshold is not None and avg_reward >= self.config.easy_threshold: + pool = "easy" + elif self.config.hard_threshold is not None and avg_reward <= self.config.hard_threshold: + pool = "hard" + else: + pool = "normal" - def write_jsonl(lst: list, path: Path) -> None: - with open(path, "w") as f: - for item in lst: - f.write(json.dumps(item, default=make_serializable) + "\n") + if pool != "normal" and example_id in self.examples: + example = self.examples.pop(example_id) + target = self.easy_examples if pool == "easy" else self.hard_examples + target.append(example) - write_jsonl(self.easy_examples, path / "easy_examples.jsonl") - write_jsonl(self.hard_examples, path / "hard_examples.jsonl") - write_jsonl(self.rollout_buffer, path / "rollout_buffer.jsonl") + self.num_examples_per_step[pool] += 1 + return pool - def load(self, path: Path) -> None: - """Loads pool assignments and rollouts.""" + def reset_step_metrics(self) -> None: + zero = lambda: {p: 0 for p in POOLS} + self.num_examples_per_step = zero() + self.num_rollouts_per_step = zero() - def read_jsonl(path: Path) -> list[dict]: - with open(path, "r") as f: - return [json.loads(line) for line in f] + def get_metrics(self) -> dict[str, float]: + metrics = {} + num_examples = sum(self.num_examples_per_step.values()) + num_rollouts = sum(self.num_rollouts_per_step.values()) - saved_easy_examples = read_jsonl(path / "easy_examples.jsonl") - saved_hard_examples = read_jsonl(path / "hard_examples.jsonl") - saved_rollout_buffer = cast(list[vf.RolloutOutput], read_jsonl(path / "rollout_buffer.jsonl")) - - if any(saved_easy_examples) or any(saved_hard_examples) or any(saved_rollout_buffer): - # Build hash lookup for example buffer (env -> (example_hash -> example_id)) - example_hash_lookup = defaultdict(dict) - all_hashes = set() - for env in self.example_buffer: - for example_id, example in self.example_buffer[env].items(): - example_hash = self.get_example_hash(example) - if example_hash in all_hashes: - self.logger.warning( - f"Duplicate example hash found based on hash_keys={self.config.hash_keys}. Overwriting with latest example. This may cause unexpected behavior when resuming the buffer." - ) - example_hash_lookup[env][example_hash] = example_id - all_hashes.add(example_hash) - - def move_saved_pool(saved_examples: list[dict], target_pool: list[dict]) -> int: - """Moves saved examples to the target pool from example buffer based on hash lookup.""" - num_moved = 0 - for example in saved_examples: - example_hash = self.get_example_hash(example) - for env in example_hash_lookup: - if example_hash in example_hash_lookup[env]: - example_id = example_hash_lookup[env][example_hash] - example = self.example_buffer[env].pop(example_id, None) - if example is not None: - target_pool.append(example) - num_moved += 1 - break - return num_moved - - if any(saved_easy_examples): - num_moved = move_saved_pool(saved_easy_examples, self.easy_examples) - self.logger.debug( - f"Loaded {num_moved}/{len(saved_easy_examples)} example(s) to easy pool from checkpoint." - ) - if num_moved != len(saved_easy_examples): - num_not_moved = len(saved_easy_examples) - num_moved - self.logger.warning( - f"Could not move {num_not_moved} example(s) from checkpoint to easy pool. This usually means you resumed with an env mix that does not contain all previous examples." - ) + for pool in ["easy", "hard"]: + if num_examples: + metrics[f"evicted_examples/{self.env_name}/{pool}"] = self.num_examples_per_step[pool] / num_examples + if num_rollouts: + metrics[f"filtered_rollouts/{self.env_name}/{pool}"] = self.num_rollouts_per_step[pool] / num_rollouts - if any(saved_hard_examples): - num_moved = move_saved_pool(saved_hard_examples, self.hard_examples) - self.logger.debug( - f"Moved {num_moved}/{len(saved_hard_examples)} example(s) to hard pool from checkpoint." - ) - if num_moved != len(saved_hard_examples): - num_not_moved = len(saved_hard_examples) - num_moved - self.logger.warning( - f"Could not move {num_not_moved} example(s) from checkpoint to hard pool. This usually means you resumed with an env mix that does not contain all previous examples." - ) + pool_counts = [len(self.easy_examples), self.num_normal, len(self.hard_examples)] + pool_ratios = mean_normalize(pool_counts) + for pool, ratio in zip(POOLS, pool_ratios): + metrics[f"pool/{self.env_name}/{pool}"] = ratio - if any(saved_rollout_buffer): - # Extend rollout buffer, but only include rollouts for which the example still exists in the example buffer - valid_saved_rollouts = [ - rollout for rollout in saved_rollout_buffer if rollout["task"] in self.env_names - ] - self.rollout_buffer.extend(valid_saved_rollouts) - self.logger.debug(f"Loaded {len(valid_saved_rollouts)} rollout(s) from checkpoint.") - - # Load rollouts, filtering out removed environments and problems - def convert_examples_to_normal(examples: list[dict], fraction: float) -> int: - """Moves a fraction of examples from the given pool back to normal.""" - if fraction <= 0.0 or not examples: - return 0 - num_moved = round(len(examples) * fraction) - if num_moved <= 0: - return 0 - for _ in range(num_moved): - example = random.choice(examples) - env_name = example["task"] - example_id = example["example_id"] - examples.remove(example) - self.example_buffer[env_name][example_id] = example - return num_moved - - num_easy_examples = len(self.easy_examples) - num_moved = convert_examples_to_normal(self.easy_examples, self.config.easy_fraction) - self.logger.debug(f"Converted {num_moved}/{num_easy_examples} example(s) back to normal from easy pool.") - num_hard_examples = len(self.hard_examples) - num_moved = convert_examples_to_normal(self.hard_examples, self.config.hard_fraction) - self.logger.debug(f"Converted {num_moved}/{num_hard_examples} example(s) back to normal from hard pool.") - else: - self.logger.debug("No easy/ hard examples or rollouts found in checkpoint") + self.reset_step_metrics() + return metrics - def sample_examples(self, n: int) -> list[dict]: - """Samples n examples from the buffer, respecting env ratios.""" - non_empty_envs = [env for env, examples in self.example_buffer.items() if examples] +class BufferSet: + """Manages multiple EnvBuffers with env-ratio-aware sampling.""" - if not non_empty_envs: - raise ValueError("No environments left with examples.") + def __init__(self, envs: Envs, config: BufferConfig): + self.config = config + self.logger = get_logger() - non_empty_env_probs = [self.env_probs[env] for env in non_empty_envs] - sampled_examples = [] - for sampled_env in random.choices(non_empty_envs, weights=non_empty_env_probs, k=n): - sampled_example = random.choice(list(self.example_buffer[sampled_env].values())) - sampled_examples.append(sampled_example) + if config.seed is not None: + random.seed(config.seed) - return sampled_examples + self.env_buffers: dict[str, EnvBuffer] = {} + for env in envs: + ds = env.get_dataset(seed=config.seed) + if "example_id" not in ds.column_names: + ds = ds.map(lambda ex, idx: {**ex, "example_id": idx}, with_indices=True) + self.env_buffers[env.name] = EnvBuffer(env.name, ds, config) + self.env_names = envs.names - def update(self, rollouts: list[vf.RolloutOutput]): - """Updates the buffer state with completed rollouts.""" + total = sum(eb.num_total for eb in self.env_buffers.values()) + self.logger.debug( + f"Initialized buffer with {format_num(total, precision=0)} example(s) " + f"in {len(self.env_names)} environment(s)" + ) + + env_ratios = [env.config.ratio for env in envs] + if any(r is not None for r in env_ratios): + env_ratio = mean_normalize(env_ratios) + self.env_probs = dict(zip(self.env_names, env_ratio)) + self.logger.debug( + f"Sampling buffer according to provided environment ratios " + f"({', '.join(f'{k}={v:.2f}' for k, v in self.env_probs.items())})" + ) + else: + env_counts = [self.env_buffers[name].num_normal for name in self.env_names] + env_ratio = mean_normalize(env_counts) + self.env_probs = dict(zip(self.env_names, env_ratio)) + self.logger.debug( + f"Sampling buffer according to natural environment distribution " + f"({', '.join(f'{k}={v:.2f}' for k, v in self.env_probs.items())})" + ) + + self.rollout_buffer: list[vf.RolloutOutput] = [] + def sample_examples(self, n: int) -> list[dict]: + """Samples n examples across envs, respecting env ratios.""" + non_empty = [name for name, eb in self.env_buffers.items() if eb.examples] + if not non_empty: + raise ValueError("No environments left with examples.") + + weights = [self.env_probs[name] for name in non_empty] + return [self.env_buffers[name].sample_example() for name in random.choices(non_empty, weights=weights, k=n)] + + def update(self, rollouts: list[vf.RolloutOutput]): + """Updates buffer state with completed rollouts.""" rollouts_by_example = defaultdict(list) for rollout in rollouts: - rollouts_by_example[rollout["example_id"]].append(rollout) + rollouts_by_example[(rollout["env_name"], rollout["example_id"])].append(rollout) - for example_id, example_rollouts in rollouts_by_example.items(): + for (env_name, example_id), example_rollouts in rollouts_by_example.items(): + eb = self.env_buffers[env_name] avg_reward = mean([r["reward"] for r in example_rollouts]) - env_name = example_rollouts[0]["task"] + eb.update_pools(example_id, avg_reward, len(example_rollouts)) - if self.config.easy_threshold is not None and avg_reward >= self.config.easy_threshold: - pool = "easy" - elif self.config.hard_threshold is not None and avg_reward <= self.config.hard_threshold: - pool = "hard" - else: - pool = "normal" - - if pool != "normal" and example_id in self.example_buffer[env_name]: - example = self.example_buffer[env_name].pop(example_id) - target_pool = self.easy_examples if pool == "easy" else self.hard_examples - target_pool.append(example) - - self.num_examples_per_step[env_name][pool] += 1 if self.config.online_difficulty_filtering: if avg_reward == 0.0: - self.num_rollouts_per_step[env_name]["hard"] += len(example_rollouts) + eb.num_rollouts_per_step["hard"] += len(example_rollouts) continue elif avg_reward == 1.0: - self.num_rollouts_per_step[env_name]["easy"] += len(example_rollouts) + eb.num_rollouts_per_step["easy"] += len(example_rollouts) continue - self.num_rollouts_per_step[env_name]["normal"] += len(example_rollouts) + eb.num_rollouts_per_step["normal"] += len(example_rollouts) self.rollout_buffer.extend(example_rollouts) def sample_rollouts(self, n: int) -> list[vf.RolloutOutput]: """Samples the latest n rollouts from the buffer.""" n = min(n, len(self.rollout_buffer)) - sampled_rollouts = self.rollout_buffer[-n:] + sampled = self.rollout_buffer[-n:] self.rollout_buffer = self.rollout_buffer[:-n] - return sampled_rollouts + return sampled - def reset_step_metrics(self) -> None: - """Reset per-step metrics (called after get_metrics).""" - zero_per_pool = lambda: {p: 0 for p in self.POOLS} - # num examples per env per step per pool (env_name -> (pool -> num_examples)) - self.num_examples_per_step = {env: zero_per_pool() for env in self.env_names} - # num rollouts per env per step per pool (env_name -> (pool -> num_rollouts)) - self.num_rollouts_per_step = {env: zero_per_pool() for env in self.env_names} + def save(self, path: Path) -> None: + """Saves pool assignments and rollout buffer.""" + path.mkdir(parents=True, exist_ok=True) - def get_metrics(self) -> dict[str, float]: - """Returns the buffer metrics for the current step.""" + def write_jsonl(lst: list, filepath: Path) -> None: + with open(filepath, "w") as f: + for item in lst: + f.write(json.dumps(item, default=make_serializable) + "\n") - metrics = {} - easy_examples_per_env = defaultdict(int) - hard_examples_per_env = defaultdict(int) - for example in self.easy_examples: - easy_examples_per_env[example["task"]] += 1 - for example in self.hard_examples: - hard_examples_per_env[example["task"]] += 1 - - # sum over envs (e.g. log globally) - num_examples_per_step_per_pool = { - pool: sum(self.num_examples_per_step[env][pool] for env in self.env_names) for pool in self.POOLS - } - num_rollouts_per_step_per_pool = { - pool: sum(self.num_rollouts_per_step[env][pool] for env in self.env_names) for pool in self.POOLS - } - num_examples_per_step = sum(num_examples_per_step_per_pool.values()) - num_rollouts_per_step = sum(num_rollouts_per_step_per_pool.values()) + all_easy = [ex for eb in self.env_buffers.values() for ex in eb.easy_examples] + all_hard = [ex for eb in self.env_buffers.values() for ex in eb.hard_examples] + write_jsonl(all_easy, path / "easy_examples.jsonl") + write_jsonl(all_hard, path / "hard_examples.jsonl") + write_jsonl(self.rollout_buffer, path / "rollout_buffer.jsonl") - for pool in ["easy", "hard"]: - if num_examples_per_step: - metrics[f"evicted_examples/{pool}"] = num_examples_per_step_per_pool[pool] / num_examples_per_step - if num_rollouts_per_step: - metrics[f"filtered_rollouts/{pool}"] = num_rollouts_per_step_per_pool[pool] / num_rollouts_per_step + def load(self, path: Path) -> None: + """Loads pool assignments and rollouts from checkpoint.""" - total_normal = sum(len(self.example_buffer[env]) for env in self.env_names) - pool_counts = [len(self.easy_examples), total_normal, len(self.hard_examples)] - pool_ratios = mean_normalize(pool_counts) - for pool, pool_ratio in zip(self.POOLS, pool_ratios): - metrics[f"pool/{pool}"] = pool_ratio - - for env in self.env_names: - env_num_examples_per_step_per_pool = self.num_examples_per_step[env] - env_num_rollouts_per_step_per_pool = self.num_rollouts_per_step[env] - env_num_examples_per_step = sum(env_num_examples_per_step_per_pool.values()) - env_num_rollouts_per_step = sum(env_num_rollouts_per_step_per_pool.values()) - - for pool in ["easy", "hard"]: - if env_num_examples_per_step: - metrics[f"evicted_examples/{env}/{pool}"] = ( - env_num_examples_per_step_per_pool[pool] / env_num_examples_per_step - ) - if env_num_rollouts_per_step: - metrics[f"filtered_rollouts/{env}/{pool}"] = ( - env_num_rollouts_per_step_per_pool[pool] / env_num_rollouts_per_step + def read_jsonl(filepath: Path) -> list[dict]: + with open(filepath, "r") as f: + return [json.loads(line) for line in f] + + saved_easy = read_jsonl(path / "easy_examples.jsonl") + saved_hard = read_jsonl(path / "hard_examples.jsonl") + saved_rollouts = cast(list[vf.RolloutOutput], read_jsonl(path / "rollout_buffer.jsonl")) + + if not any(saved_easy) and not any(saved_hard) and not any(saved_rollouts): + self.logger.debug("No easy/ hard examples or rollouts found in checkpoint") + return + + # Build hash lookup across all env buffers: env -> (hash -> example_id) + hash_lookup: dict[str, dict[str, int]] = defaultdict(dict) + all_hashes: set[str] = set() + for env_name, eb in self.env_buffers.items(): + for example_id, example in eb.examples.items(): + h = eb.get_example_hash(example) + if h in all_hashes: + self.logger.warning( + f"Duplicate example hash found based on hash_keys={self.config.hash_keys}. " + "Overwriting with latest example. This may cause unexpected behavior when resuming the buffer." ) + hash_lookup[env_name][h] = example_id + all_hashes.add(h) + + def move_saved_pool(saved_examples: list[dict], pool_name: str) -> int: + num_moved = 0 + for example in saved_examples: + # Use any env buffer to compute hash (hash_keys are config-level) + first_eb = next(iter(self.env_buffers.values())) + h = first_eb.get_example_hash(example) + for env_name, env_hashes in hash_lookup.items(): + if h in env_hashes: + example_id = env_hashes[h] + eb = self.env_buffers[env_name] + matched = eb.examples.pop(example_id, None) + if matched is not None: + target = eb.easy_examples if pool_name == "easy" else eb.hard_examples + target.append(matched) + num_moved += 1 + break + return num_moved + + if any(saved_easy): + num_moved = move_saved_pool(saved_easy, "easy") + self.logger.debug(f"Loaded {num_moved}/{len(saved_easy)} example(s) to easy pool from checkpoint.") + if num_moved != len(saved_easy): + self.logger.warning( + f"Could not move {len(saved_easy) - num_moved} example(s) from checkpoint to easy pool. " + "This usually means you resumed with an env mix that does not contain all previous examples." + ) + + if any(saved_hard): + num_moved = move_saved_pool(saved_hard, "hard") + self.logger.debug(f"Moved {num_moved}/{len(saved_hard)} example(s) to hard pool from checkpoint.") + if num_moved != len(saved_hard): + self.logger.warning( + f"Could not move {len(saved_hard) - num_moved} example(s) from checkpoint to hard pool. " + "This usually means you resumed with an env mix that does not contain all previous examples." + ) - env_pool_counts = [ - easy_examples_per_env[env], - len(self.example_buffer[env]), - hard_examples_per_env[env], - ] - env_pool_ratios = mean_normalize(env_pool_counts) - for pool, pool_ratio in zip(self.POOLS, env_pool_ratios): - metrics[f"pool/{env}/{pool}"] = pool_ratio + if any(saved_rollouts): + valid = [r for r in saved_rollouts if r.get("env_name", r.get("task")) in self.env_names] + self.rollout_buffer.extend(valid) + self.logger.debug(f"Loaded {len(valid)} rollout(s) from checkpoint.") + + def convert_to_normal(eb: EnvBuffer, pool: list[dict], fraction: float) -> int: + if fraction <= 0.0 or not pool: + return 0 + num_to_move = round(len(pool) * fraction) + if num_to_move <= 0: + return 0 + for _ in range(num_to_move): + example = random.choice(pool) + pool.remove(example) + eb.examples[example["example_id"]] = example + return num_to_move + + for eb in self.env_buffers.values(): + n_easy = len(eb.easy_examples) + moved = convert_to_normal(eb, eb.easy_examples, self.config.easy_fraction) + self.logger.debug(f"Converted {moved}/{n_easy} example(s) back to normal from easy pool ({eb.env_name}).") + n_hard = len(eb.hard_examples) + moved = convert_to_normal(eb, eb.hard_examples, self.config.hard_fraction) + self.logger.debug(f"Converted {moved}/{n_hard} example(s) back to normal from hard pool ({eb.env_name}).") - self.reset_step_metrics() + def get_metrics(self) -> dict[str, float]: + metrics = {} + + # Aggregate cross-env totals + total_examples_per_pool = {p: 0 for p in POOLS} + total_rollouts_per_pool = {p: 0 for p in POOLS} + for eb in self.env_buffers.values(): + for p in POOLS: + total_examples_per_pool[p] += eb.num_examples_per_step[p] + total_rollouts_per_pool[p] += eb.num_rollouts_per_step[p] + + total_examples = sum(total_examples_per_pool.values()) + total_rollouts = sum(total_rollouts_per_pool.values()) + + for pool in ["easy", "hard"]: + if total_examples: + metrics[f"evicted_examples/{pool}"] = total_examples_per_pool[pool] / total_examples + if total_rollouts: + metrics[f"filtered_rollouts/{pool}"] = total_rollouts_per_pool[pool] / total_rollouts + + total_normal = sum(eb.num_normal for eb in self.env_buffers.values()) + total_easy = sum(len(eb.easy_examples) for eb in self.env_buffers.values()) + total_hard = sum(len(eb.hard_examples) for eb in self.env_buffers.values()) + pool_ratios = mean_normalize([total_easy, total_normal, total_hard]) + for pool, ratio in zip(POOLS, pool_ratios): + metrics[f"pool/{pool}"] = ratio + + # Per-env metrics + for eb in self.env_buffers.values(): + metrics.update(eb.get_metrics()) return metrics diff --git a/src/prime_rl/orchestrator/ckpt.py b/src/prime_rl/orchestrator/ckpt.py index 19277e3063..7df20360fc 100644 --- a/src/prime_rl/orchestrator/ckpt.py +++ b/src/prime_rl/orchestrator/ckpt.py @@ -5,7 +5,7 @@ import torch from prime_rl.configs.orchestrator import CheckpointConfig -from prime_rl.orchestrator.buffer import Buffer +from prime_rl.orchestrator.buffer import BufferSet as Buffer from prime_rl.utils.logger import get_logger from prime_rl.utils.utils import get_ckpt_dir, get_step_path diff --git a/src/prime_rl/orchestrator/env_server/env_server.py b/src/prime_rl/orchestrator/env_server/env_server.py index 64e67c783a..0e989cc863 100644 --- a/src/prime_rl/orchestrator/env_server/env_server.py +++ b/src/prime_rl/orchestrator/env_server/env_server.py @@ -7,7 +7,7 @@ from prime_rl.utils.logger import setup_logger from prime_rl.utils.pathing import get_log_dir from prime_rl.utils.process import set_proc_title -from prime_rl.utils.utils import clean_exit, get_env_ids_to_install, install_env, strip_env_version +from prime_rl.utils.utils import clean_exit, get_env_ids_to_install, install_env @clean_exit @@ -23,7 +23,7 @@ def run_server(config: EnvServerConfig): log_dir = (get_log_dir(config.output_dir) / config.env.resolved_name).as_posix() server = ZMQEnvServer( - env_id=strip_env_version(config.env.id), + env_id=config.env.stripped_id, env_args=config.env.args, extra_env_kwargs=config.env.extra_env_kwargs, log_level=config.log.level, diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 2416a0c1f0..4f23cbb9ac 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -1,35 +1,387 @@ -from typing import TYPE_CHECKING, Any +from __future__ import annotations -from prime_rl.utils.envs import _ENV_PARSERS as _BASE_ENV_PARSERS, get_env_value, get_dir, set_defaults +import asyncio +import atexit +import multiprocessing as mp +import time +from collections.abc import Awaitable, Callable, Sequence +from multiprocessing.process import BaseProcess +from pathlib import Path -if TYPE_CHECKING: - # Enable type checking for shared envs - # ruff: noqa - from prime_rl.utils.envs import * +import pandas as pd +import verifiers as vf +from verifiers.serve import ZMQEnvClient, ZMQEnvServer +from verifiers.utils.serve_utils import get_free_port - # vLLM - VLLM_CONFIGURE_LOGGING: int +from prime_rl.configs.orchestrator import EnvConfig, EvalEnvConfig, TrainEnvConfig +from prime_rl.orchestrator.eval_utils import compute_pass_at_k +from prime_rl.orchestrator.vf_utils import get_completion_len, resolve_num_workers +from prime_rl.utils.logger import ProgressTracker, get_logger +from prime_rl.utils.monitor import get_monitor +from prime_rl.utils.utils import capitalize - # tqdm - TQDM_DISABLE: int +REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args"] -_ORCHESTRATOR_ENV_PARSERS = { - "VLLM_CONFIGURE_LOGGING": int, - "TQDM_DISABLE": int, - **_BASE_ENV_PARSERS, -} +class Env: + """Wraps a vf.Environment - only exposes features used in PRIME-RL.""" -_ORCHESTRATOR_ENV_DEFAULTS = { - "VLLM_CONFIGURE_LOGGING": "0", -} + def __init__(self, config: EnvConfig): + self.config = config + self.sampling_args: dict = {} -set_defaults(_ORCHESTRATOR_ENV_DEFAULTS) + self._env: vf.Environment = vf.load_environment(config.stripped_id, **config.args) + self._env_client: ZMQEnvClient | None = None + self._env_server_process: BaseProcess | None = None + @property + def name(self) -> str: + return self.config.resolved_name -def __getattr__(name: str) -> Any: - return get_env_value(_ORCHESTRATOR_ENV_PARSERS, name) + @property + def env(self) -> vf.Environment: + return self._env + @property + def env_client(self) -> ZMQEnvClient: + if not self._env_client: + raise RuntimeError( + f"Env {self.name} has no env client connected. Call connect() first to connect to an env server." + ) + return self._env_client -def __dir__() -> list[str]: - return get_dir(_ORCHESTRATOR_ENV_PARSERS) + @property + def requires_group_scoring(self) -> bool: + return any(self.env.rubric._is_group_func(func) for func in self.env.rubric._get_reward_funcs()) + + def get_dataset(self, seed: int | None = None): + return self.env.get_dataset(seed=seed) + + def spawn( + self, + log_dir: Path, + max_concurrent: int | None = None, + log_level: str | None = None, + json_logging: bool = False, + ) -> None: + """Spawn an env server if no explicit address is configured.""" + if self.config.address is not None: + return + num_workers = resolve_num_workers(self.config.num_workers, max_concurrent) + address = f"tcp://127.0.0.1:{get_free_port()}" + process = mp.get_context("spawn").Process( + target=ZMQEnvServer.run_server, + args=( + self.config.stripped_id, + self.config.args, + self.config.extra_env_kwargs, + log_level, + (log_dir / self.name).as_posix(), + ), + kwargs=dict( + address=address, + json_logging=json_logging, + console_logging=False, + num_workers=num_workers, + ), + daemon=False, + ) + process.start() + self.config.address = address + self._env_server_process = process + get_logger().info(f"Spawned env server for {self.name} with {num_workers} worker(s)") + + async def connect(self) -> None: + """Connect an env client to the server and assign it.""" + if self.config.address is None: + raise RuntimeError( + f"Env {self.name} has no address configured. Call spawn() first or set address in config." + ) + get_logger().info(f"Connecting env {self.name} to server at {self.config.address}") + self._env_client = ZMQEnvClient(address=self.config.address, name=self.name) + await self.env_client.wait_for_server_startup() + + async def run_rollout( + self, + client: vf.ClientConfig, + example: vf.RolloutInput, + model_name: str, + ) -> vf.RolloutOutput: + return await self.env.run_rollout( + example, + client=client, + model=model_name, + sampling_args=self.sampling_args, + max_retries=self.config.max_retries, + state_columns=REQUIRED_STATE_COLUMNS, + env_client=self.env_client, + ) + + async def run_group( + self, + client: vf.ClientConfig, + example: vf.RolloutInput, + model_name: str, + rollouts_per_example: int, + ) -> list[vf.RolloutOutput]: + return await self.env.run_group( + [example for _ in range(rollouts_per_example)], + client=client, + model=model_name, + sampling_args=self.sampling_args, + max_retries=self.config.max_retries, + state_columns=REQUIRED_STATE_COLUMNS, + env_client=self.env_client, + ) + + def shutdown(self) -> None: + if self._env_server_process is None: + return + logger = get_logger() + self._env_server_process.terminate() + self._env_server_process.join(timeout=25) + if self._env_server_process.is_alive(): + logger.warning(f"Env server {self._env_server_process.pid} did not exit after 25s, force killing") + self._env_server_process.kill() + self._env_server_process.join(timeout=5) + self._env_server_process = None + + +class TrainEnv(Env): + config: TrainEnvConfig + + def __init__(self, config: TrainEnvConfig): + super().__init__(config) + + +class EvalEnv(Env): + config: EvalEnvConfig + + def __init__(self, config: EvalEnvConfig): + super().__init__(config) + + def get_dataset(self, seed: int | None = None): + return self.env.get_eval_dataset(seed=seed) + + async def evaluate( + self, + model_name: str, + get_client: Callable[[], Awaitable[vf.ClientConfig]], + ckpt_step: int, + step: int, + ) -> None: + n, k = self.config.num_examples, self.config.rollouts_per_example + get_logger().info(f"Evaluating {self.name} (num_examples={n}, rollouts_per_example={k})") + + inputs = self.env._get_eval_inputs(n, k) + pbar = ProgressTracker(total=len(inputs), desc=f"Evaluating {self.name}") + eval_start = time.perf_counter() + + if self.requires_group_scoring: + # Group scoring: batch inputs into groups of k and run as groups + groups = [inputs[i : i + k] for i in range(0, len(inputs), k)] + + async def _run_group(group: list) -> list[vf.RolloutOutput] | None: + try: + client = await get_client() + group_inputs = [vf.RolloutInput(**ex) for ex in group] + result = await self._env.run_group( + group_inputs, + client=client, + model=model_name, + sampling_args=self.sampling_args, + max_retries=self.config.max_retries, + state_columns=REQUIRED_STATE_COLUMNS, + env_client=self.env_client, + ) + pbar.update(len(group)) + return result + except Exception as e: + get_logger().warning(f"Group failed: {e}") + pbar.update(len(group)) + return None + + try: + group_results = await asyncio.gather(*[_run_group(g) for g in groups]) + finally: + pbar.close() + + successful_outputs = [o for group in group_results if group is not None for o in group] + failed_count = sum(len(g) for g, r in zip(groups, group_results) if r is None) + else: + # Individual scoring: run each input independently + async def _run_rollout(example: dict) -> vf.RolloutOutput | None: + try: + client = await get_client() + output = await self.run_rollout(client=client, example=example, model_name=model_name) + pbar.update(1) + return output + except Exception as e: + get_logger().warning(f"Rollout failed: {e}") + pbar.update(1) + return None + + try: + outputs = await asyncio.gather(*[_run_rollout(example) for example in inputs]) + finally: + pbar.close() + + successful_outputs = [o for o in outputs if o is not None] + failed_count = sum(1 for o in outputs if o is None) + + eval_time = time.perf_counter() - eval_start + + total_count = len(inputs) + if failed_count: + get_logger().warning( + f"{failed_count}/{total_count} ({failed_count / total_count * 100:.1f}%) rollouts failed" + ) + + if not successful_outputs: + get_logger().warning(f"All rollouts failed for {self.name}, skipping logging metrics") + get_monitor().log( + { + f"eval/{self.name}/failed_rollouts": failed_count / total_count, + "progress/ckpt_step": ckpt_step, + "step": step, + }, + step=step, + ) + return + + # Log metrics + monitor = get_monitor() + + rows = [ + { + "example_id": o["example_id"], + "reward": o["reward"], + "completion_len": get_completion_len(o), + "is_truncated": o["is_truncated"], + "has_error": o.get("error") is not None, + "no_response": not o.get("completion"), + } + for o in successful_outputs + ] + results_df = pd.DataFrame(rows) + + unique_rewards = results_df.reward.dropna().unique() + could_be_binary = set(unique_rewards).issubset({0.0, 1.0}) + if could_be_binary: + pass_at_k = ( + results_df.groupby("example_id") + .apply(lambda x: compute_pass_at_k(x.reward.dropna()), include_groups=False) + .apply(pd.Series) + ) + else: + pass_at_k = None + get_logger().warning("Skipping computing pass@k rates because the task rewards appear to be non-binary") + + message = f"Evaluated {self.name} in {eval_time:.2f}s (Avg@{k}={results_df.reward.mean():.4f}" + if could_be_binary: + for pass_rate, pass_rate_score in pd.Series(pass_at_k.mean()).items(): + message += f", {capitalize(str(pass_rate))}: {pass_rate_score:.4f}" + + message += ( + f", No-response: {results_df.no_response.mean() * 100:.1f}%" + f", Completion Length: {results_df.completion_len.mean():.2f} (±{results_df.completion_len.std():.2f}, ∈[{results_df.completion_len.min():.2f}, {results_df.completion_len.max():.2f}])" + f", Truncated: {results_df.is_truncated.mean() * 100:.1f}%)" + ) + get_logger().success(message) + + eval_metrics = { + f"avg@{k}": float(results_df.reward.mean()), + "no_response/mean": float(results_df.no_response.mean()), + "no_response/count": int(results_df.no_response.sum()), + "completion_len/mean": results_df.completion_len.mean().item(), + "completion_len/max": results_df.completion_len.max().item(), + "completion_len/min": results_df.completion_len.min().item(), + "is_truncated/mean": results_df.is_truncated.mean().item(), + "failed_rollouts": failed_count / total_count, + "time": eval_time, + } + if could_be_binary: + assert pass_at_k is not None + eval_metrics.update(pd.Series(pass_at_k.mean()).to_dict()) + eval_metrics = {f"eval/{self.name}/{k}": v for k, v in eval_metrics.items()} + eval_metrics.update({"progress/ckpt_step": ckpt_step, "step": step}) + monitor.log(eval_metrics, step=step) + monitor.log_eval_samples(successful_outputs, env_name=self.name, step=step) + + +class Envs: + """Base container for a set of Env instances.""" + + _envs: dict[str, Env] + + @property + def names(self) -> list[str]: + return list(self._envs.keys()) + + @property + def configs(self) -> list[EnvConfig]: + return [env.config for env in self._envs.values()] + + def get(self, name: str) -> Env: + return self._envs[name] + + def set_sampling_args(self, sampling_args: dict) -> None: + for env in self: + env.sampling_args = sampling_args + + def __iter__(self): + return iter(self._envs.values()) + + def __len__(self) -> int: + return len(self._envs) + + def spawn( + self, + log_dir: Path, + max_concurrent: int | None = None, + log_level: str | None = None, + json_logging: bool = False, + ) -> None: + """Spawn env servers for all envs without an explicit address.""" + for env in self: + env.spawn( + log_dir=log_dir, + max_concurrent=max_concurrent, + log_level=log_level, + json_logging=json_logging, + ) + atexit.register(self.shutdown) + + async def connect(self) -> None: + """Connect all env clients to their servers and wait for health (in parallel).""" + await asyncio.gather(*(env.connect() for env in self)) + + def shutdown(self) -> None: + """Terminate all spawned env server processes.""" + processes = [env._env_server_process for env in self if env._env_server_process is not None] + if not processes: + return + logger = get_logger() + logger.info(f"Shutting down {len(processes)} env server(s), waiting for sandbox cleanup...") + for env in self: + env.shutdown() + + +class TrainEnvs(Envs): + """Collection of training environments.""" + + def __init__(self, configs: Sequence[TrainEnvConfig]): + self._envs: dict[str, Env] = {} + for config in configs: + env = TrainEnv(config) + self._envs[env.name] = env + + +class EvalEnvs(Envs): + """Collection of evaluation environments.""" + + def __init__(self, configs: Sequence[EvalEnvConfig]): + self._envs: dict[str, Env] = {} + for config in configs: + env = EvalEnv(config) + self._envs[env.name] = env diff --git a/src/prime_rl/orchestrator/eval_utils.py b/src/prime_rl/orchestrator/eval_utils.py index 5c66e9ef36..75118d63fa 100644 --- a/src/prime_rl/orchestrator/eval_utils.py +++ b/src/prime_rl/orchestrator/eval_utils.py @@ -1,16 +1,4 @@ -import time -from collections.abc import Awaitable, Callable -from typing import Any - import numpy as np -import pandas as pd -import verifiers as vf - -from prime_rl.configs.orchestrator import EvalSamplingConfig -from prime_rl.orchestrator.vf_utils import evaluate, get_completion_len -from prime_rl.utils.logger import get_logger -from prime_rl.utils.monitor import get_monitor -from prime_rl.utils.utils import capitalize def compute_eval_ckpt_step( @@ -39,38 +27,6 @@ def compute_eval_ckpt_step( return None -def get_eval_sampling_args(sampling_config: EvalSamplingConfig) -> dict[str, Any]: - """Get sampling args for evaluation.""" - # Initialize sampling args - sampling_args: dict[str, Any] = {} - - # Apply sampling arguments, if specified - if sampling_config.temperature is not None: - sampling_args["temperature"] = sampling_config.temperature - if sampling_config.max_tokens is not None: - sampling_args["max_tokens"] = sampling_config.max_tokens - if sampling_config.top_p is not None: - sampling_args["top_p"] = sampling_config.top_p - if sampling_config.reasoning_effort is not None: - sampling_args["reasoning_effort"] = sampling_config.reasoning_effort - - extra_body: dict[str, Any] = sampling_config.extra_body.copy() - - # Apply vLLM-specific sampling arguments, if specified - if sampling_config.top_k is not None: - extra_body["top_k"] = sampling_config.top_k - if sampling_config.min_p is not None: - extra_body["min_p"] = sampling_config.min_p - if sampling_config.min_tokens is not None: - extra_body["min_tokens"] = sampling_config.min_tokens - if sampling_config.repetition_penalty is not None: - extra_body["repetition_penalty"] = sampling_config.repetition_penalty - - sampling_args["extra_body"] = extra_body - - return sampling_args - - def _pass_at_k(n: int, c: int, k: int) -> float: """Unbiased estimator of pass@k (Chen et al., 2021). @@ -86,101 +42,3 @@ def compute_pass_at_k(rewards: list[float]) -> dict[str, float]: c = sum(r == 1.0 for r in rewards) ks = [2**i for i in range(n.bit_length())] return {f"pass@{k}": _pass_at_k(n, c, k) for k in ks} - - -async def evaluate_env( - env: vf.Environment, - env_name: str, - model_name: str, - sampling_args: dict, - num_examples: int, - rollouts_per_example: int, - max_retries: int, - ckpt_step: int, - step: int, - get_client: Callable[[], Awaitable[vf.ClientConfig]], -): - logger = get_logger() - logger.info(f"Evaluating {env_name} ({num_examples=}, {rollouts_per_example=})") - eval_start_time = time.perf_counter() - total_inputs = len(env._get_eval_inputs(num_examples, rollouts_per_example)) - outputs = await evaluate( - env=env, - model_name=model_name, - sampling_args=sampling_args, - num_examples=num_examples, - rollouts_per_example=rollouts_per_example, - get_client=get_client, - max_retries=max_retries, - ) - eval_time = time.perf_counter() - eval_start_time - failed_rollouts = total_inputs - len(outputs) - - if not outputs: - logger.warning(f"All rollouts failed for {env_name} ({failed_rollouts} failed), skipping metrics") - monitor = get_monitor() - monitor.log( - {f"eval/{env_name}/failed_rollouts": failed_rollouts, "progress/ckpt_step": ckpt_step, "step": step}, - step=step, - ) - return - - rows = [] - for output in outputs: - rows.append( - { - "example_id": output["example_id"], - "reward": output["reward"], - "completion_len": get_completion_len(output), - "is_truncated": output["is_truncated"], - "has_error": output.get("error") is not None, - "no_response": not output.get("completion"), - } - ) - results_df = pd.DataFrame(rows) - - unique_rewards = results_df.reward.dropna().unique() - could_be_binary = set(unique_rewards).issubset({0.0, 1.0}) - if could_be_binary: - pass_at_k = ( - results_df.groupby("example_id") - .apply(lambda x: compute_pass_at_k(x.reward.dropna()), include_groups=False) - .apply(pd.Series) - ) - else: - pass_at_k = None - logger.warning("Skipping computing pass@k rates because the task rewards appear to be non-binary") - - # Log statistics to console - message = f"Evaluated {env_name} in {eval_time:.2f}s (Avg@{rollouts_per_example}={results_df.reward.mean():.4f}" - if could_be_binary: - assert pass_at_k is not None - for pass_rate, pass_rate_score in pd.Series(pass_at_k.mean()).items(): - message += f", {capitalize(str(pass_rate))}: {pass_rate_score:.4f}" - message += ( - f", No-response: {results_df.no_response.mean() * 100:.1f}%" - f", Completion Length: {results_df.completion_len.mean():.2f} (±{results_df.completion_len.std():.2f}, ∈[{results_df.completion_len.min():.2f}, {results_df.completion_len.max():.2f}])" - f", Truncated: {results_df.is_truncated.mean() * 100:.1f}%)" - ) - logger.success(message) - - # Log statistics to monitor - eval_metrics = { - f"avg@{rollouts_per_example}": float(results_df.reward.mean()), - "no_response/mean": float(results_df.no_response.mean()), - "no_response/count": int(results_df.no_response.sum()), - "completion_len/mean": results_df.completion_len.mean().item(), - "completion_len/max": results_df.completion_len.max().item(), - "completion_len/min": results_df.completion_len.min().item(), - "is_truncated/mean": results_df.is_truncated.mean().item(), - "failed_rollouts": failed_rollouts, - "time": eval_time, - } - if could_be_binary: - assert pass_at_k is not None - eval_metrics.update(pd.Series(pass_at_k.mean()).to_dict()) - eval_metrics = {**{f"eval/{env_name}/{k}": v for k, v in eval_metrics.items()}} - eval_metrics.update({"progress/ckpt_step": ckpt_step, "step": step}) - monitor = get_monitor() - monitor.log(eval_metrics, step=step) - monitor.log_eval_samples(outputs, env_name=env_name, step=step) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 5227d0358a..ffd47a0dd9 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -1,14 +1,12 @@ import asyncio -import atexit import gc -import multiprocessing as mp import time from concurrent.futures import ThreadPoolExecutor import tomli_w from prime_rl.orchestrator.advantage import compute_advantages -from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step, get_eval_sampling_args +from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor from prime_rl.orchestrator.patches import monkey_patch_chat_completion_logprobs, monkey_patch_oai_iterable_types from prime_rl.orchestrator.trajectories import ( @@ -33,30 +31,24 @@ import verifiers as vf from transformers import AutoProcessor, AutoTokenizer -from prime_rl.configs.orchestrator import BufferConfig, OrchestratorConfig -from prime_rl.orchestrator.buffer import Buffer +from prime_rl.configs.orchestrator import OrchestratorConfig +from prime_rl.orchestrator.buffer import BufferSet from prime_rl.orchestrator.ckpt import Progress, setup_ckpt_manager -from prime_rl.orchestrator.eval_utils import evaluate_env +from prime_rl.orchestrator.envs import EvalEnvs, TrainEnvs from prime_rl.orchestrator.filters import apply_filters, setup_filters from prime_rl.orchestrator.scheduler import Scheduler from prime_rl.orchestrator.utils import ( compute_teacher_logprobs, - get_sampling_args, + get_eval_sampling_args, + get_train_sampling_args, get_weight_dir, print_benchmark, - set_semaphore, setup_external_rollout_model, ) from prime_rl.orchestrator.vf_utils import ( - generate, get_completion_len, get_seq_len, intercept_vf_logging, - resolve_num_workers, - setup_env_client, - spawn_env_server, - task_uses_group_scoring, - wait_for_env_servers, ) from prime_rl.utils.client import ( init_nccl_broadcast, @@ -67,13 +59,11 @@ from prime_rl.utils.logger import setup_logger from prime_rl.utils.monitor import setup_monitor from prime_rl.utils.process import set_proc_title -from prime_rl.utils.temp_scheduling import compute_temperature from prime_rl.utils.utils import ( clean_exit, get_env_ids_to_install, install_env, resolve_latest_ckpt_step, - strip_env_version, to_col_format, ) @@ -170,156 +160,38 @@ async def orchestrate(config: OrchestratorConfig): if rollout_filters: logger.info(f"Initialized {len(rollout_filters)} rollout filter(s): {[f.name for f in rollout_filters]}") - # Load environment and extract dataset - logger.info( - f"Loading {len(config.env)} training environment(s) ({', '.join(env.resolved_name for env in config.env)})" - ) - env_ids = [strip_env_version(env.id) for env in config.env] - train_env_names = [env.resolved_name for env in config.env] - train_env_group = vf.EnvGroup( - envs=[vf.load_environment(env_id, **env.args) for env_id, env in zip(env_ids, config.env)], - env_names=train_env_names, - map_kwargs=dict(writer_batch_size=1), # set defensively to not error on map operations on large datasets - ) - verification_enabled = config.verification.enabled + # Load environments + logger.info("Loading training environments") + train_envs = TrainEnvs(config.train_envs) + train_envs.set_sampling_args(get_train_sampling_args(config.sampling, is_vllm=config.teacher_rollout_model is None)) + logger.info(f"Loaded {len(train_envs)} training environment(s) ({', '.join(train_envs.names)})") - train_env_deferred_group_scoring_tasks = ( - {env_name for env_name in train_env_names if task_uses_group_scoring(train_env_group, env_name)} - if verification_enabled - else set() + train_envs.spawn( + log_dir=get_log_dir(config.output_dir.parent) / "envs" / "train", + max_concurrent=config.max_inflight_rollouts, + log_level=config.log.vf_level, + json_logging=config.log.json_logging, ) - for train_env_name, env_cfg in zip(train_env_names, config.env): - env_cfg.extra_env_kwargs["score_rollouts"] = ( - verification_enabled and train_env_name not in train_env_deferred_group_scoring_tasks - ) - if not verification_enabled: - logger.info("Verification disabled; all training envs will skip scoring.") - elif train_env_deferred_group_scoring_tasks: - deferred_tasks = ", ".join(sorted(train_env_deferred_group_scoring_tasks)) - logger.info( - f"Deferred group scoring enabled for training tasks: {deferred_tasks}. " - "Rollouts run individually and are scored once each group completes." - ) - - train_env_addresses = [] - env_processes: list[mp.Process] = [] - - def _cleanup_env_processes(): - if not env_processes: - return - logger.info(f"Shutting down {len(env_processes)} env server(s), waiting for sandbox cleanup...") - for proc in env_processes: - proc.terminate() - for proc in env_processes: - proc.join(timeout=25) - if proc.is_alive(): - logger.warning(f"Env server {proc.pid} did not exit after 25s, force killing") - proc.kill() - proc.join(timeout=5) - - atexit.register(_cleanup_env_processes) - - for env_id, env, env_name in zip(env_ids, config.env, train_env_names): - if env.address is None: - num_workers = resolve_num_workers(env.num_workers, config.max_inflight_rollouts) - log_dir = (get_log_dir(config.output_dir.parent) / "envs" / "train" / env_name).as_posix() - address, process = spawn_env_server( - env_id=env_id, - env_args=env.args, - extra_env_kwargs=env.extra_env_kwargs, - num_workers=num_workers, - log_level=config.log.vf_level, - log_dir=log_dir, - json_logging=config.log.json_logging, - ) - logger.info(f"Spawned env server for {env_name} with {num_workers} worker(s)") - env_processes.append(process) - else: - if env_name in train_env_deferred_group_scoring_tasks: - logger.warning( - f"Training env {env_name} uses external server at {env.address}. " - "Ensure that server was started with score_rollouts=False." - ) - address = env.address - logger.info(f"Connecting train environment {env_name} to server at {address}") - train_env_addresses.append(address) - train_env_clients = [ - setup_env_client(address=address, name=name) for name, address in zip(train_env_names, train_env_addresses) - ] - - logger.info("Waiting for train environment servers to be ready") - await wait_for_env_servers(train_env_clients) - logger.success("Train environment servers ready") - - # this puts all train envs into server model - # all calls to run_rollout and run_group will be routed to the server via the env client - for env, env_client in zip(train_env_group.envs, train_env_clients): - env.env_client = env_client + await train_envs.connect() + logger.success("Train environments ready") if config.eval: - env_ids = [strip_env_version(env.id) for env in config.eval.env] - eval_envs = [vf.load_environment(env_id, **env.args) for env_id, env in zip(env_ids, config.eval.env)] - eval_env_names = [env.resolved_name for env in config.eval.env] - eval_sampling_args = get_eval_sampling_args(config.eval.sampling) - eval_env_addresses = [] - - for env_id, env, eval_env_name in zip(env_ids, config.eval.env, eval_env_names): - if env.address is None: - num_examples = env.num_examples or config.eval.num_examples - rollouts_per_example = env.rollouts_per_example or config.eval.rollouts_per_example - if num_examples == -1: - max_concurrent = 1024 - logger.warning( - f"Eval env '{eval_env_name}' uses all examples (num_examples=-1). " - f"Defaulting max_concurrent={max_concurrent} for worker scaling." - ) - else: - max_concurrent = num_examples * rollouts_per_example - num_workers = resolve_num_workers(env.num_workers, max_concurrent) - log_dir = (get_log_dir(config.output_dir.parent) / "envs" / "eval" / eval_env_name).as_posix() - address, process = spawn_env_server( - env_id=env_id, - env_args=env.args, - extra_env_kwargs=env.extra_env_kwargs, - num_workers=num_workers, - log_level=config.log.vf_level, - log_dir=log_dir, - json_logging=config.log.json_logging, - ) - logger.info(f"Spawned eval env server for {eval_env_name} with {num_workers} worker(s)") - env_processes.append(process) - else: - address = env.address - logger.info(f"Connecting eval environment {eval_env_name} to server at {address}") - eval_env_addresses.append(address) - - eval_env_clients = [ - setup_env_client(address=address, name=name) for name, address in zip(eval_env_names, eval_env_addresses) - ] - - logger.info("Waiting for eval environment servers to be ready") - await wait_for_env_servers(eval_env_clients) - logger.success("Eval environment servers ready") - - # this puts all eval envs into server mode - # all calls to run_rollout and run_group will be routed to the server via the env client - for eval_env, eval_env_client in zip(eval_envs, eval_env_clients): - eval_env.env_client = eval_env_client - else: - eval_envs: list[vf.Environment] = [] - eval_env_names: list[str] = [] - eval_sampling_args = {} + logger.info("Loading eval environments") + eval_envs = EvalEnvs(config.eval_envs) + eval_envs.set_sampling_args(get_eval_sampling_args(config.eval.sampling)) + logger.info(f"Loaded {len(eval_envs)} eval environment(s) ({', '.join(eval_envs.names)})") + + eval_envs.spawn( + log_dir=get_log_dir(config.output_dir.parent) / "envs" / "eval", + log_level=config.log.vf_level, + json_logging=config.log.json_logging, + ) + await eval_envs.connect() + logger.success("Eval environments ready") # Setup buffer logger.info(f"Setting up buffer ({config.buffer})") - train_dataset = train_env_group.get_dataset(seed=config.buffer.seed) - buffer = Buffer(train_dataset, train_env_group.env_names, config.buffer) - if config.val is not None: - val_buffer_config = BufferConfig(env_ratios=config.buffer.env_ratios) - val_dataset = train_env_group.get_eval_dataset(seed=val_buffer_config.seed) - val_buffer = Buffer(val_dataset, train_env_group.env_names, val_buffer_config) - else: - val_buffer = None + buffer = BufferSet(train_envs, config.buffer) # Get checkpoint manager logger.info(f"Initializing checkpoint manager ({config.ckpt})") @@ -333,7 +205,7 @@ def _cleanup_env_processes(): checkpoint_step = config.ckpt.resume_step scheduler = Scheduler( - env=train_env_group, + envs=train_envs, buffer=buffer, inference_pool=inference_pool, max_inflight_rollouts=config.max_inflight_rollouts, @@ -343,7 +215,6 @@ def _cleanup_env_processes(): tasks_per_minute=config.tasks_per_minute, enable_policy_updates=enable_policy_updates, lora_name=config.model.lora.name if config.model.lora else None, - deferred_group_scoring_tasks=train_env_deferred_group_scoring_tasks, config=config, ) scheduler.model_name = rollout_model_name @@ -416,10 +287,8 @@ def _cleanup_env_processes(): logger.info("Training from scratch") # Iterate over dataset in batches - max_steps = config.max_steps or int(1e9) - logger.info(f"Starting orchestrator loop (max_steps={max_steps or 'infinite'})") + logger.info(f"Starting orchestrator loop (max_steps={config.max_steps or 'infinite'})") is_first_step = True - await set_semaphore(config.max_concurrent or -1) # Persistent ThreadPoolExecutor for parallel rollout processing rollout_executor = ThreadPoolExecutor(max_workers=64) @@ -469,7 +338,6 @@ def _cleanup_env_processes(): eval_base_model=config.eval.eval_base_model, ) - if eval_ckpt_step is not None: last_eval_step = ckpt_step if eval_ckpt_step != ckpt_step: logger.info(f"Running evals for interval step {eval_ckpt_step} (current ckpt_step={ckpt_step})") @@ -485,21 +353,15 @@ def _cleanup_env_processes(): logger.info("Cancelling in-flight training rollouts before starting evals to avoid congestion.") await scheduler.cancel_inflight_rollouts() - results = await asyncio.gather( + await asyncio.gather( *[ - evaluate_env( - env=eval_env, - env_name=eval_env_name, - get_client=inference_pool.get_next_client, + eval_env.evaluate( model_name=scheduler.model_name, - sampling_args=eval_sampling_args, - num_examples=eval_env_config.num_examples or config.eval.num_examples, - rollouts_per_example=eval_env_config.rollouts_per_example or config.eval.rollouts_per_example, - max_retries=eval_env_config.max_retries, + get_client=inference_pool.get_next_client, ckpt_step=ckpt_step, step=progress.step, ) - for eval_env, eval_env_name, eval_env_config in zip(eval_envs, eval_env_names, config.eval.env) + for eval_env in eval_envs ] ) @@ -510,30 +372,8 @@ def _cleanup_env_processes(): prev_ckpt_step = ckpt_step # Schedule generating the training batch - temperature = compute_temperature(progress.step, config.sampling, config.max_steps) - is_vllm = config.teacher_rollout_model is None - sampling_args = get_sampling_args(config.sampling, temperature=temperature, is_vllm=is_vllm) - scheduler.set_sampling_args(sampling_args) train_task = asyncio.create_task(scheduler.generate_batch(step=progress.step)) - # Schedule running validation at the specified interval - if val_buffer and config.val and progress.step % config.val.interval == 0: - logger.info(f"Running validation for step {progress.step}") - val_examples = val_buffer.sample_examples(config.val.num_examples) - val_task = asyncio.create_task( - generate( - env=train_env_group, - model_name=scheduler.model_name, - examples=val_examples, - rollouts_per_example=config.val.rollouts_per_example, - sampling_args=sampling_args, - clients=inference_pool.clients, - pbar_description="Generating rollouts (val)", - ) - ) - else: - val_task = asyncio.create_task(asyncio.sleep(0)) # Dummy task - # Await train rollouts await train_task generate_completions_time = scheduler.last_batch_generation_time @@ -651,17 +491,13 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin training_batch_sender.send(training_batch) - # Await and process val results - await val_task - val_outputs = val_task.result() - step_time = time.perf_counter() - step_start_time # Gather metrics in dataframes results_df = pd.DataFrame( { "example_id": [rollout["example_id"] for rollout in train_rollouts], - "task": [rollout["task"] for rollout in train_rollouts], + "env_name": [rollout["env_name"] for rollout in train_rollouts], "reward": [rollout["reward"] for rollout in train_rollouts], "is_truncated": [rollout["is_truncated"] for rollout in train_rollouts], "stop_condition": [rollout.get("stop_condition") for rollout in train_rollouts], @@ -678,18 +514,6 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin # Separate DataFrame for env reward function metrics to avoid column name collisions metrics_df = pd.DataFrame([rollout["metrics"] for rollout in train_rollouts]) - val_results_df = ( - pd.DataFrame( - { - "example_id": [rollout["example_id"] for rollout in val_outputs], - "task": [rollout["task"] for rollout in val_outputs], - "reward": [rollout["reward"] for rollout in val_outputs], - } - ) - if val_outputs is not None - else None - ) - # Update progress metrics num_tokens = int(results_df.seq_len.sum()) progress.total_tokens += num_tokens @@ -753,12 +577,12 @@ def compute_solve_rates(df): "reward/all/mean": by_example.reward.mean().mean(), "reward/all/max": by_example.reward.mean().max(), "reward/all/min": by_example.reward.mean().min(), - "sampling/temperature": temperature, + "sampling/temperature": config.sampling.temperature, # Solve / batch metrics "solve_none/all": solve_none, "solve_all/all": solve_all, "effective_batch_size/all": effective_batch_size, - **{f"batch/{env}": r for env, r in results_df.task.value_counts(normalize=True).items()}, + **{f"batch/{env}": r for env, r in results_df.env_name.value_counts(normalize=True).items()}, # Time metrics "time/step": step_time, "time/generate_completions": generate_completions_time, @@ -789,7 +613,7 @@ def compute_solve_rates(df): "scoring_ms", ] - for env, env_df in results_df.groupby("task"): + for env, env_df in results_df.groupby("env_name"): env_by_example = env_df.groupby("example_id") for col in per_env_columns: to_log[f"{col}/{env}/mean"] = env_by_example[col].mean().mean() @@ -812,18 +636,6 @@ def compute_solve_rates(df): for metric in metrics_df.columns: to_log[f"metrics/{env}/{metric}"] = env_metrics_df.groupby(env_df["example_id"])[metric].mean().mean() - # Optionally, add val metrics - if val_results_df is not None: - val_by_example = val_results_df.groupby("example_id") - to_log["val/reward/all/mean"] = val_by_example.reward.mean().mean() - to_log["val/reward/all/max"] = val_by_example.reward.mean().max() - to_log["val/reward/all/min"] = val_by_example.reward.mean().min() - for env, env_df in val_results_df.groupby("task"): - env_by_example = env_df.groupby("example_id") - to_log[f"val/reward/{env}/mean"] = env_by_example.reward.mean().mean() - to_log[f"val/reward/{env}/max"] = env_by_example.reward.mean().max() - to_log[f"val/reward/{env}/min"] = env_by_example.reward.mean().min() - # Log metrics to monitor(s) monitor.log(to_log, step=progress.step) @@ -840,11 +652,7 @@ def compute_solve_rates(df): ) reward_mean = by_example.reward.mean().mean() - val_reward_str = "" - if val_results_df is not None: - val_reward_mean = val_results_df.groupby("example_id").reward.mean().mean() - val_reward_str = f" Val. Reward: {val_reward_mean:.4f} |" - step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Reward: {reward_mean:.4f} |{val_reward_str} Seq. Length: {by_example.seq_len.mean().mean():.1f} tokens/sample | Async Level: {scheduler.async_level} | Max. Off-Policy Level: {scheduler.max_off_policy_level}" + step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Reward: {reward_mean:.4f} | Seq. Length: {by_example.seq_len.mean().mean():.1f} tokens/sample | Async Level: {scheduler.async_level} | Max. Off-Policy Level: {scheduler.max_off_policy_level}" logger.success(step_message) # Increment step @@ -853,7 +661,7 @@ def compute_solve_rates(df): # Free large per-step objects to prevent memory accumulation del train_rollouts, train_examples, training_batch, vlm_cache - del results_df, metrics_df, val_results_df + del results_df, metrics_df gc.collect() event_loop_lag_monitor.reset() @@ -862,23 +670,17 @@ def compute_solve_rates(df): if heart is not None: heart.beat() - if config.eval: + if config.eval and eval_envs is not None: logger.info("Running final evals") - results = await asyncio.gather( + await asyncio.gather( *[ - evaluate_env( - env=eval_env, - env_name=eval_env_name, - get_client=inference_pool.get_next_client, + eval_env.evaluate( model_name=scheduler.model_name, - sampling_args=eval_sampling_args, - num_examples=eval_env_config.num_examples or config.eval.num_examples, - rollouts_per_example=eval_env_config.rollouts_per_example or config.eval.rollouts_per_example, - max_retries=eval_env_config.max_retries, + get_client=inference_pool.get_next_client, ckpt_step=ckpt_step, step=progress.step, ) - for eval_env, eval_env_name, eval_env_config in zip(eval_envs, eval_env_names, config.eval.env) + for eval_env in eval_envs ] ) @@ -910,8 +712,9 @@ def compute_solve_rates(df): event_loop_lag_monitor_task.cancel() # Shutdown env processes (also registered as atexit handler for crash safety) - atexit.unregister(_cleanup_env_processes) - _cleanup_env_processes() + train_envs.shutdown() + if eval_envs is not None: + eval_envs.shutdown() logger.success("Orchestrator finished.") diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index 1ee154d935..4531a1fdf6 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -4,19 +4,18 @@ import time from collections import Counter, defaultdict from dataclasses import dataclass, field -from typing import NamedTuple, cast +from typing import NamedTuple import verifiers as vf from aiolimiter import AsyncLimiter from prime_rl.configs.orchestrator import OrchestratorConfig -from prime_rl.orchestrator.buffer import Buffer -from prime_rl.orchestrator.utils import get_sampling_args -from prime_rl.orchestrator.vf_utils import get_seq_len, run_rollout +from prime_rl.orchestrator.buffer import BufferSet as Buffer +from prime_rl.orchestrator.envs import Envs +from prime_rl.orchestrator.vf_utils import get_seq_len from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all from prime_rl.utils.client import InferencePool from prime_rl.utils.logger import ProgressTracker, get_logger -from prime_rl.utils.temp_scheduling import compute_temperature from prime_rl.utils.utils import ( get_broadcast_dir, get_latest_ckpt_step, @@ -30,7 +29,7 @@ class InflightRolloutInfo(NamedTuple): off_policy_steps: int client_config: vf.ClientConfig - task: str + env_name: str group_id: int | None = None @@ -57,7 +56,7 @@ class Scheduler: def __init__( self, - env: vf.Environment, + envs: Envs, inference_pool: InferencePool, buffer: Buffer, config: OrchestratorConfig, @@ -68,14 +67,13 @@ def __init__( tasks_per_minute: int | None, enable_policy_updates: bool = True, lora_name: str | None = None, - deferred_group_scoring_tasks: set[str] | None = None, ): self.logger = get_logger() if tasks_per_minute is not None: self.rate_limiter = AsyncLimiter(max_rate=tasks_per_minute, time_period=60) else: self.rate_limiter = None - self.env = env + self.envs = envs self.buffer = buffer self.config = config self.batch_size = config.batch_size @@ -87,20 +85,15 @@ def __init__( self.strict_async_level = strict_async_level self.enable_policy_updates = enable_policy_updates self.lora_name = lora_name - initial_temp = compute_temperature(step=0, sampling_config=config.sampling, max_steps=config.max_steps) - is_vllm = config.teacher_rollout_model is None - self.sampling_args = get_sampling_args(config.sampling, temperature=initial_temp, is_vllm=is_vllm) self.model_name = self.config.model.name self.json_logging = config.log.json_logging # Inference pool - used for admin operations (adapter sync) and metrics self.inference_pool = inference_pool - self.max_retries_by_task = {env.resolved_name: env.max_retries for env in config.env} - self.deferred_group_scoring_tasks = set(deferred_group_scoring_tasks or ()) - if self.deferred_group_scoring_tasks: - task_list = ", ".join(sorted(self.deferred_group_scoring_tasks)) - self.logger.info(f"Deferred group scoring active for task(s): {task_list}") + group_scoring_tasks = [env.name for env in envs if env.requires_group_scoring] + if group_scoring_tasks: + self.logger.info(f"Group rollout scoring active for task(s): {', '.join(group_scoring_tasks)}") # Track in-flight requests: task -> info self.inflight_requests: dict[asyncio.Task, InflightRolloutInfo] = {} @@ -144,10 +137,6 @@ def finalize_batch_rollouts(self, rollouts: list[vf.RolloutOutput]) -> list[vf.R return rollouts return rollouts[: self.batch_size] - def set_sampling_args(self, sampling_args: dict) -> None: - """Update sampling args for future rollout requests.""" - self.sampling_args = sampling_args - async def cancel_inflight_rollouts(self): """Cancel all in-flight rollout requests.""" count = len(self.inflight_requests) @@ -186,13 +175,13 @@ async def drop_group(self, group_id: int) -> int: return len(tasks_to_cancel) async def schedule_rollout(self, group_id: int): - """Asynchronously schedules a rollout request.""" + """Asynchronously schedules a rollout request (or a group request for group-scoring envs).""" if self.rate_limiter: await self.rate_limiter.acquire() group = self.groups.get(group_id) if group is None or group.rollouts_to_schedule <= 0: return - group.rollouts_to_schedule -= 1 + if group.pinned_client is not None: client_config = group.pinned_client else: @@ -200,18 +189,31 @@ async def schedule_rollout(self, group_id: int): if group_id not in self.groups: return group.pinned_client = client_config - run_rollout_task = asyncio.create_task( - run_rollout( - env=self.env, - client=client_config, - example=group.example, - model_name=self.model_name, - sampling_args=self.sampling_args, - max_retries=self.max_retries_by_task.get(group.example["task"], 0), + + env_name = group.example["env_name"] + env = self.envs.get(env_name) + + if env.requires_group_scoring: + group.rollouts_to_schedule = 0 + task = asyncio.create_task( + env.run_group( + client=client_config, + example=group.example, + model_name=self.model_name, + rollouts_per_example=self.rollouts_per_example, + ) ) - ) - self.inflight_requests[run_rollout_task] = InflightRolloutInfo( - off_policy_steps=0, client_config=client_config, task=group.example["task"], group_id=group_id + else: + group.rollouts_to_schedule -= 1 + task = asyncio.create_task( + env.run_rollout( + client=client_config, + example=group.example, + model_name=self.model_name, + ) + ) + self.inflight_requests[task] = InflightRolloutInfo( + off_policy_steps=0, client_config=client_config, env_name=env_name, group_id=group_id ) @property @@ -348,19 +350,6 @@ async def _update_off_policy(self) -> None: f"Consider increasing max_off_policy_steps to avoid this." ) - def _should_defer_group_scoring(self, task: str) -> bool: - return task in self.deferred_group_scoring_tasks and self.config.verification.enabled - - async def _score_group_if_deferred(self, completed_rollouts: list[vf.RolloutOutput]) -> list[vf.RolloutOutput]: - if not completed_rollouts: - return completed_rollouts - task = completed_rollouts[0]["task"] - if not self._should_defer_group_scoring(task): - return completed_rollouts - env_for_task = self.env.get_env_for_task(task) - await env_for_task.rubric.score_group(cast(list[vf.State], completed_rollouts)) - return completed_rollouts - async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: """Continuously generates a batch of rollouts.""" self.step = step @@ -407,40 +396,51 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: continue group_id = rollout_info.group_id + env_name = rollout_info.env_name try: group = self.groups.get(group_id) if group is None: continue - rollout = finished_task.result() - - task = rollout_info.task - self.total_rollouts_by_task[task] += 1 - should_reschedule = False - if len(rollout["trajectory"]) == 0: - self.empty_rollouts_by_task[task] += 1 - should_reschedule = True - self.logger.warning( - f"Empty trajectory in group {group_id} ({task}), re-scheduling " - f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete)" - ) - if rollout["error"] is not None: - self.errored_rollouts_by_task[task] += 1 - should_reschedule = True - self.logger.warning( - f"Rollout error in group {group_id} ({task}), re-scheduling " - f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete): " - f"{rollout['error']['error_chain_repr']}" - ) - if should_reschedule: - group.rollouts_to_schedule += 1 - continue - group.completed_rollouts.append(rollout) - if len(group.completed_rollouts) < self.rollouts_per_example: - continue - completed_rollouts = self.groups.pop(group_id).completed_rollouts - completed_rollouts = await self._score_group_if_deferred(completed_rollouts) + env = self.envs.get(env_name) + if env.requires_group_scoring: + # run_group returns all rollouts at once, already scored + group_rollouts: list[vf.RolloutOutput] = finished_task.result() + self.total_rollouts_by_task[env_name] += len(group_rollouts) + for rollout in group_rollouts: + rollout["env_name"] = env_name + completed_rollouts = group_rollouts + self.groups.pop(group_id, None) + else: + rollout = finished_task.result() + self.total_rollouts_by_task[env_name] += 1 + should_reschedule = False + if len(rollout["trajectory"]) == 0: + self.empty_rollouts_by_task[env_name] += 1 + should_reschedule = True + self.logger.warning( + f"Empty trajectory in group {group_id} ({env_name}), re-scheduling " + f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete)" + ) + if rollout["error"] is not None: + self.errored_rollouts_by_task[env_name] += 1 + should_reschedule = True + self.logger.warning( + f"Rollout error in group {group_id} ({env_name}), re-scheduling " + f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete): " + f"{rollout['error']['error_chain_repr']}" + ) + if should_reschedule: + group.rollouts_to_schedule += 1 + continue + + rollout["env_name"] = env_name + group.completed_rollouts.append(rollout) + if len(group.completed_rollouts) < self.rollouts_per_example: + continue + completed_rollouts = self.groups.pop(group_id).completed_rollouts + except asyncio.CancelledError: if group_id is not None: await self.drop_group(group_id) @@ -513,7 +513,7 @@ def get_metrics(self) -> dict[str, float]: metrics[f"errored_rollouts/{task}"] = self.errored_rollouts_by_task.get(task, 0) / task_total by_task: dict[str, list[int]] = {} for info in self.inflight_requests.values(): - by_task.setdefault(info.task, []).append(info.off_policy_steps) + by_task.setdefault(info.env_name, []).append(info.off_policy_steps) for task, steps in by_task.items(): metrics[f"off_policy_level/{task}/max"] = max(steps) metrics[f"off_policy_level/{task}/mean"] = sum(steps) / len(steps) diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 5c6f56a82f..00809b08e7 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -2,7 +2,7 @@ import time from itertools import cycle from pathlib import Path -from typing import Any, AsyncContextManager +from typing import Any import pandas as pd import verifiers as vf @@ -10,10 +10,9 @@ from openai.types.completion_usage import CompletionUsage from rich.console import Console from rich.table import Table -from verifiers.utils.async_utils import maybe_semaphore from verifiers.utils.client_utils import setup_openai_client -from prime_rl.configs.orchestrator import OrchestratorConfig, SamplingConfig +from prime_rl.configs.orchestrator import EvalSamplingConfig, OrchestratorConfig, SamplingConfig from prime_rl.transport import TrainingSample from prime_rl.utils.utils import ( format_time, @@ -22,26 +21,12 @@ get_step_path, ) -SEMAPHORE: AsyncContextManager | None = None - -async def set_semaphore(limit: int): - global SEMAPHORE - SEMAPHORE = await maybe_semaphore(limit) - - -async def get_semaphore() -> AsyncContextManager: - global SEMAPHORE - assert SEMAPHORE is not None, "Semaphore not set" - return SEMAPHORE - - -def get_sampling_args(sampling_config: SamplingConfig, temperature: float, is_vllm: bool = True) -> dict: +def get_train_sampling_args(sampling_config: SamplingConfig, is_vllm: bool = True) -> dict: # Convert SamplingConfig to vLLM OAI sampling args # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters_2 sampling_args = dict(sampling_config) sampling_args.pop("temp_scheduler", None) - sampling_args["temperature"] = temperature sampling_args["top_p"] = 1.0 sampling_args["logprobs"] = True extra_body = dict(sampling_config.extra_body) @@ -65,6 +50,35 @@ def get_sampling_args(sampling_config: SamplingConfig, temperature: float, is_vl return sampling_args +def get_eval_sampling_args(sampling_config: EvalSamplingConfig) -> dict[str, Any]: + """Get sampling args for evaluation.""" + sampling_args: dict[str, Any] = {} + + if sampling_config.temperature is not None: + sampling_args["temperature"] = sampling_config.temperature + if sampling_config.max_tokens is not None: + sampling_args["max_tokens"] = sampling_config.max_tokens + if sampling_config.top_p is not None: + sampling_args["top_p"] = sampling_config.top_p + if sampling_config.reasoning_effort is not None: + sampling_args["reasoning_effort"] = sampling_config.reasoning_effort + + extra_body: dict[str, Any] = sampling_config.extra_body.copy() + + if sampling_config.top_k is not None: + extra_body["top_k"] = sampling_config.top_k + if sampling_config.min_p is not None: + extra_body["min_p"] = sampling_config.min_p + if sampling_config.min_tokens is not None: + extra_body["min_tokens"] = sampling_config.min_tokens + if sampling_config.repetition_penalty is not None: + extra_body["repetition_penalty"] = sampling_config.repetition_penalty + + sampling_args["extra_body"] = extra_body + + return sampling_args + + def parse_num_completion_tokens(responses: list[list[ChatCompletion]]) -> list[int]: """Parses the number of tokens from a list of chat completions returned by OAI API.""" all_num_completion_tokens = [] @@ -159,21 +173,20 @@ async def compute_teacher_logprobs( async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: client = setup_openai_client(client_config) - async with await get_semaphore(): - response = await client.post( - "/chat/completions/tokens", - body={ - "model": model_name, - "messages": [{"role": "user", "content": ""}], - "tokens": sample.prompt_ids + sample.completion_ids, - "max_tokens": 1, - "temperature": 1.0, - "top_p": 1.0, - "skip_special_tokens": False, - "prompt_logprobs": True, - }, - cast_to=ChatCompletion, - ) + response = await client.post( + "/chat/completions/tokens", + body={ + "model": model_name, + "messages": [{"role": "user", "content": ""}], + "tokens": sample.prompt_ids + sample.completion_ids, + "max_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "skip_special_tokens": False, + "prompt_logprobs": True, + }, + cast_to=ChatCompletion, + ) return [ 0.0 if lp is None else float(next(iter(lp.values()))["logprob"]) for lp in getattr(response, "prompt_logprobs", []) diff --git a/src/prime_rl/orchestrator/vf_utils.py b/src/prime_rl/orchestrator/vf_utils.py index 7c20deb171..f329755165 100644 --- a/src/prime_rl/orchestrator/vf_utils.py +++ b/src/prime_rl/orchestrator/vf_utils.py @@ -1,21 +1,9 @@ -import asyncio import logging import math -import multiprocessing as mp -from collections.abc import Awaitable, Callable -from itertools import cycle -from typing import Any import verifiers as vf -from verifiers.serve import EnvClient, ZMQEnvClient, ZMQEnvServer -from verifiers.utils.serve_utils import get_free_port - -from prime_rl.utils.logger import InterceptHandler, ProgressTracker, get_logger - -DEFAULT_RETRIES = 0 -REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args"] -DEFAULT_STATE_COLUMNS = [] +from prime_rl.utils.logger import InterceptHandler WORKERS_PER_CONCURRENCY = 256 @@ -32,224 +20,6 @@ def resolve_num_workers(num_workers: int | str, max_concurrent: int | None = Non return int(num_workers) -def spawn_env_server( - env_id: str, - env_args: dict[str, Any], - extra_env_kwargs: dict[str, Any], - address: str | None = None, - num_workers: int = 1, - # logging configs - log_level: str | None = None, - log_dir: str | None = None, - json_logging: bool = False, -) -> tuple[str, mp.Process]: - """ - Starts a ZMQEnvServer process in a subprocess. - - Mirrors vf.Environment.start_server(). - """ - address = address or f"tcp://127.0.0.1:{get_free_port()}" - # Use spawn to avoid inheriting file descriptors (e.g. sockets) from - # the parent process, which has caused hangs when multiple env server - # subprocesses share the same fds. - process = mp.get_context("spawn").Process( - target=ZMQEnvServer.run_server, - args=( - env_id, - env_args, - extra_env_kwargs, - log_level, - log_dir, - ), - kwargs=dict( - address=address, - json_logging=json_logging, - console_logging=False, - num_workers=num_workers, - ), - daemon=False, # cannot run daemon because env server uses subprocesses - ) - process.start() - - return address, process - - -def setup_env_client( - address: str, - name: str | None = None, - # health check configs - health_check_interval: float = 5.0, # 5s (we detect an env server as unhealth after 3 * 5s = 15s of unsuccessful health checks) - startup_timeout: float = 600.0, # 10m - recovery_timeout: float = 600.0, # 10m -) -> EnvClient: - """Sets up a ZMQEnvClient for a given address.""" - return ZMQEnvClient( - address=address, - name=name, - health_check_interval=health_check_interval, - startup_timeout=startup_timeout, - recovery_timeout=recovery_timeout, - ) - - -async def wait_for_env_servers(env_clients: list[EnvClient]) -> None: - await asyncio.gather(*[env_client.wait_for_server_startup() for env_client in env_clients]) - - -async def run_rollout( - env: vf.Environment, - client: vf.ClientConfig, - model_name: str, - example: dict, - sampling_args: dict, - max_retries: int = DEFAULT_RETRIES, - state_columns: list[str] = DEFAULT_STATE_COLUMNS, -) -> vf.RolloutOutput: - """ - Wrapper for vf.Environment.run_rollout(). - - Asynchronously generates and scores one rollout. - """ - state_columns = state_columns + REQUIRED_STATE_COLUMNS - rollout_input = vf.RolloutInput(**example) - return await env.run_rollout( - rollout_input, - client=client, - model=model_name, - sampling_args=sampling_args, - max_retries=max_retries, - state_columns=state_columns, - ) - - -async def run_group( - env: vf.Environment, - client: vf.ClientConfig, - model_name: str, - example: dict, - rollouts_per_example: int, - sampling_args: dict, - max_retries: int = DEFAULT_RETRIES, - state_columns: list[str] = DEFAULT_STATE_COLUMNS, -) -> list[vf.RolloutOutput]: - """ - Wrapper for vf.Environment.run_group(). - - Asynchronously generates and scores a group. - """ - state_columns = state_columns + REQUIRED_STATE_COLUMNS - group_inputs = [vf.RolloutInput(**example) for _ in range(rollouts_per_example)] - return await env.run_group( - group_inputs, - client=client, - model=model_name, - sampling_args=sampling_args, - max_retries=max_retries, - state_columns=state_columns, - ) - - -# TODO: migrate this to vf.Environment.generate() once it supports multiple clients -async def generate( - env: vf.Environment, - model_name: str, - examples: list, - rollouts_per_example: int, - sampling_args: dict, - clients: list[vf.ClientConfig] | None = None, - get_client: Callable[[], Awaitable[vf.ClientConfig]] | None = None, - max_retries: int = DEFAULT_RETRIES, - state_columns: list[str] = DEFAULT_STATE_COLUMNS, - pbar_description: str = "Generating rollouts", -) -> list[vf.RolloutOutput]: - """ - Wrapper for vf.Environment.generate(). - - NOTE: Currently we cannot use vf.Environment.generate() directly because it does not support multiple clients. - - Asynchronously generates and scores a list of groups. - """ - - if not clients and get_client is None: - raise ValueError("generate requires at least one client or a get_client callback") - - if get_client is None: - client_cycle = cycle(clients) - - async def get_client() -> vf.ClientConfig: - return next(client_cycle) - - total_rollouts = len(examples) * rollouts_per_example - pbar = ProgressTracker(total=total_rollouts, desc=pbar_description) - - async def run_group_with_progress(example) -> list[vf.RolloutOutput] | None: - try: - client = await get_client() - result = await run_group( - env=env, - client=client, - model_name=model_name, - example=example, - rollouts_per_example=rollouts_per_example, - max_retries=max_retries, - state_columns=state_columns, - sampling_args=sampling_args, - ) - pbar.update(rollouts_per_example) - return result - except Exception as e: - get_logger().warning(f"Group failed: {e}") - pbar.update(rollouts_per_example) - return None - - try: - group_outputs_list = await asyncio.gather(*[run_group_with_progress(example) for example in examples]) - finally: - pbar.close() - - failed_groups = sum(1 for g in group_outputs_list if g is None) - if failed_groups: - get_logger().warning(f"{failed_groups}/{len(group_outputs_list)} groups failed") - - return [output for group_outputs in group_outputs_list if group_outputs is not None for output in group_outputs] - - -async def evaluate( - env: vf.Environment, - model_name: str, - sampling_args: dict, - num_examples: int, - rollouts_per_example: int, - clients: list[vf.ClientConfig] | None = None, - get_client: Callable[[], Awaitable[vf.ClientConfig]] | None = None, - max_retries: int = DEFAULT_RETRIES, - state_columns: list[str] = DEFAULT_STATE_COLUMNS, -) -> list[vf.RolloutOutput]: - """ - Wrapper for vf.Environment.evaluate(). - - NOTE: Currently we cannot use vf.Environment.evaluate() directly because it does not support multiple clients. - Instead, we use our generate() wrapper which round-robins clients. - - """ - inputs = env._get_eval_inputs(num_examples, rollouts_per_example) - return await generate( - env=env, - clients=clients, - get_client=get_client, - model_name=model_name, - examples=inputs, - # _get_eval_inputs() already repeats the examples, this currently means - # we do not support eval envs with group scoring well -- this should be - # resolved once we can use vf.Environment.generate() and - # vf.Environment.evaluate() directly though - rollouts_per_example=1, - sampling_args=sampling_args, - max_retries=max_retries, - state_columns=state_columns, - ) - - # TODO: remove once usage is tracked by verifiers def get_prompt_len(output: vf.RolloutOutput) -> int: """ @@ -292,12 +62,6 @@ def get_completion_len(output: vf.RolloutOutput) -> int: return get_seq_len(output) - get_prompt_len(output) -def task_uses_group_scoring(env: vf.Environment, task_name: str) -> bool: - """Check if a task's rubric contains any group-level reward functions.""" - rubric = env.get_env_for_task(task_name).rubric - return any(rubric._is_group_func(func) for func in rubric._get_reward_funcs()) - - def intercept_vf_logging(logger: str = "verifiers", level: str = "DEBUG", prefix: str | None = None): """Intercepts verifiers logging and routes through prime-rl logger with optional prefix.""" vf_logger = logging.getLogger(logger) diff --git a/src/prime_rl/utils/utils.py b/src/prime_rl/utils/utils.py index 81a4df7357..af5bc251c3 100644 --- a/src/prime_rl/utils/utils.py +++ b/src/prime_rl/utils/utils.py @@ -297,15 +297,6 @@ def default_dtype(dtype): torch.set_default_dtype(prev) -def strip_env_version(env_id: str) -> str: - """Strip the @version suffix from an environment ID. - - Environment IDs may include a version (e.g. 'd42me/meow@0.1.5') for installation, - but the version must be stripped before loading as a Python module. - """ - return env_id.split("@")[0] - - def install_env(env_id: str) -> None: """Install an environment in subprocess.""" logger = get_logger() diff --git a/tests/unit/orchestrator/test_buffer.py b/tests/unit/orchestrator/test_buffer.py index 16c0070120..656660d42a 100644 --- a/tests/unit/orchestrator/test_buffer.py +++ b/tests/unit/orchestrator/test_buffer.py @@ -5,8 +5,28 @@ import verifiers as vf from datasets import Dataset -from prime_rl.configs.orchestrator import BufferConfig -from prime_rl.orchestrator.buffer import Buffer +from prime_rl.configs.orchestrator import BufferConfig, EnvConfig +from prime_rl.orchestrator.buffer import BufferSet +from prime_rl.orchestrator.envs import Envs, TrainEnv + + +def make_env(name: str, vf_env: vf.Environment, **config_kwargs) -> TrainEnv: + """Create a TrainEnv without calling vf.load_environment.""" + config = EnvConfig(id=name, name=name, **config_kwargs) + env = TrainEnv.__new__(TrainEnv) + env.config = config + env._env = vf_env + env._env_client = None + env._env_server_process = None + env.sampling_args = {} + return env + + +def make_envs(env_dict: dict[str, TrainEnv]) -> Envs: + """Create an Envs container from a dict of Env instances.""" + envs = Envs.__new__(Envs) + envs._envs = env_dict + return envs @pytest.fixture(autouse=True) @@ -32,8 +52,8 @@ def dummy_dataset() -> Dataset: @pytest.fixture -def dummy_env_group(mock_openai_client, dummy_dataset) -> vf.EnvGroup: - """Return an EnvGroup with two dummy envs using the same dataset.""" +def dummy_envs(mock_openai_client, dummy_dataset) -> Envs: + """Return an Envs with two dummy envs.""" env_a = vf.SingleTurnEnv( client=mock_openai_client, model="test-model", @@ -46,22 +66,29 @@ def dummy_env_group(mock_openai_client, dummy_dataset) -> vf.EnvGroup: dataset=dummy_dataset, rubric=vf.Rubric(), ) - return vf.EnvGroup(envs=[env_a, env_b], env_names=["env_a", "env_b"]) + return make_envs( + { + "env_a": make_env("env_a", env_a), + "env_b": make_env("env_b", env_b), + } + ) @pytest.fixture def make_rollouts(): - def _make_rollouts(dataset: Dataset, rewards: list[float]) -> list[vf.RolloutOutput]: + def _make_rollouts( + buffer: BufferSet, env_name: str, indices: list[int], rewards: list[float] + ) -> list[vf.RolloutOutput]: all_rollouts = [] - for i, reward in enumerate(rewards): - task = dataset[i]["task"] - example_id = dataset[i]["example_id"] - prompt = dataset[i]["prompt"] + eb = buffer.env_buffers[env_name] + examples = list(eb.examples.values()) + for idx, reward in zip(indices, rewards): + example = examples[idx] rollouts = [ vf.RolloutOutput( - example_id=example_id, - task=task, - prompt=prompt, + example_id=example["example_id"], + task=example["task"], + prompt=example["prompt"], prompt_ids=[0], prompt_mask=[1], completion_ids=[1], @@ -73,104 +100,107 @@ def _make_rollouts(dataset: Dataset, rewards: list[float]) -> list[vf.RolloutOut metrics={}, ) ] * 2 + for r in rollouts: + r["env_name"] = env_name all_rollouts.extend(rollouts) return all_rollouts return _make_rollouts -def get_normal_ids(buffer: Buffer) -> set[int]: - return {example_id for env in buffer.example_buffer.values() for example_id in env.keys()} +def get_normal_count(buffer: BufferSet) -> int: + return sum(eb.num_normal for eb in buffer.env_buffers.values()) -def test_buffer_init_and_sample(dummy_env_group): - dataset = dummy_env_group.get_dataset() - buffer = Buffer(dataset, dummy_env_group.env_names, BufferConfig()) - # Each env has 5 examples, so total is 10 - assert len(buffer.example_buffer["env_a"]) == 5 - assert len(buffer.example_buffer["env_b"]) == 5 +def test_buffer_init_and_sample(dummy_envs): + buffer = BufferSet(dummy_envs, BufferConfig()) + assert buffer.env_buffers["env_a"].num_normal == 5 + assert buffer.env_buffers["env_b"].num_normal == 5 samples = buffer.sample_examples(2) assert len(samples) == 2 -def test_buffer_problem_pool_assignment(dummy_env_group, make_rollouts): +def test_buffer_problem_pool_assignment(dummy_envs, make_rollouts): """Problems are moved to easy/hard pools based on reward thresholds.""" - dataset = dummy_env_group.get_dataset() - buffer = Buffer(dataset, dummy_env_group.env_names, BufferConfig(easy_threshold=1.0, hard_threshold=0.0)) - dataset = buffer.dataset - # Use first 5 examples (all from env_a since they come first in concatenated dataset) - buffer.update(make_rollouts(dataset.select(range(5)), rewards=[1.0, 1.0, 0.5, 0.5, 0.0])) + buffer = BufferSet(dummy_envs, BufferConfig(easy_threshold=1.0, hard_threshold=0.0)) + buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 1.0, 0.5, 0.5, 0.0])) - assert len(buffer.easy_examples) == 2 - assert len(buffer.hard_examples) == 1 - # 2 normal from first 5, plus 5 from env_b = 7 - assert len(get_normal_ids(buffer)) == 7 + assert len(buffer.env_buffers["env_a"].easy_examples) == 2 + assert len(buffer.env_buffers["env_a"].hard_examples) == 1 + # 2 normal from env_a + 5 from env_b = 7 + assert get_normal_count(buffer) == 7 -def test_buffer_online_difficulty_filtering(dummy_env_group, make_rollouts): +def test_buffer_online_difficulty_filtering(dummy_envs, make_rollouts): """With online_difficulty_filtering=True, only partial reward rollouts are kept.""" - dataset = dummy_env_group.get_dataset() - buffer = Buffer( - dataset, - dummy_env_group.env_names, + buffer = BufferSet( + dummy_envs, BufferConfig(online_difficulty_filtering=True), ) - buffer.update(make_rollouts(dataset.select(range(5)), rewards=[1.0, 0.5, 0.0, 0.5, 0.5])) + buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 0.5, 0.0, 0.5, 0.5])) # Only 3 problems with reward 0.5 -> 6 rollouts kept assert len(buffer.rollout_buffer) == 6 -def test_buffer_no_filtering_by_default(dummy_env_group, make_rollouts): +def test_buffer_no_filtering_by_default(dummy_envs, make_rollouts): """With online_difficulty_filtering=False (default), all rollouts are kept.""" - dataset = dummy_env_group.get_dataset() - buffer = Buffer(dataset, dummy_env_group.env_names, BufferConfig()) - buffer.update(make_rollouts(dataset.select(range(5)), rewards=[1.0, 0.5, 0.0, 0.5, 0.5])) + buffer = BufferSet(dummy_envs, BufferConfig()) + buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 0.5, 0.0, 0.5, 0.5])) # All 5 problems -> 10 rollouts kept assert len(buffer.rollout_buffer) == 10 -def test_buffer_save_load_with_conversion(dummy_env_group, make_rollouts, tmp_path): +def test_buffer_save_load_with_conversion(dummy_envs, make_rollouts, tmp_path): """Easy/hard problems are partially converted to normal on load.""" - dataset = dummy_env_group.get_dataset() - buffer = Buffer(dataset, dummy_env_group.env_names, BufferConfig(easy_threshold=1.0, hard_threshold=0.0)) - buffer.update(make_rollouts(dataset.select(range(5)), rewards=[1.0, 1.0, 0.5, 0.5, 0.0])) + buffer = BufferSet(dummy_envs, BufferConfig(easy_threshold=1.0, hard_threshold=0.0)) + buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 1.0, 0.5, 0.5, 0.0])) buffer.save(tmp_path / "buffer") - new_buffer = Buffer( - dataset, dummy_env_group.env_names, BufferConfig(easy_fraction=0.5, hash_keys=["prompt", "task"]) - ) + new_buffer = BufferSet(dummy_envs, BufferConfig(easy_fraction=0.5, hash_keys=["prompt", "env_name"])) new_buffer.load(tmp_path / "buffer") # 1 of 2 easy problems converted to normal - assert len(new_buffer.easy_examples) == 1 + assert len(new_buffer.env_buffers["env_a"].easy_examples) == 1 # 2 were normal + 5 from env_b + 1 converted from easy = 8 - assert len(get_normal_ids(new_buffer)) == 8 + assert get_normal_count(new_buffer) == 8 -def test_buffer_env_ratios(dummy_env_group): - dataset = dummy_env_group.get_dataset() - buffer = Buffer(dataset, dummy_env_group.env_names, BufferConfig(env_ratios=[0.8, 0.2])) - assert len(buffer.example_buffer["env_a"]) == 5 - assert len(buffer.example_buffer["env_b"]) == 5 +def test_buffer_env_ratios(mock_openai_client, dummy_dataset): + env_a = vf.SingleTurnEnv(client=mock_openai_client, model="test-model", dataset=dummy_dataset, rubric=vf.Rubric()) + env_b = vf.SingleTurnEnv(client=mock_openai_client, model="test-model", dataset=dummy_dataset, rubric=vf.Rubric()) + envs = make_envs( + { + "env_a": make_env("env_a", env_a, ratio=0.8), + "env_b": make_env("env_b", env_b, ratio=0.2), + } + ) + + buffer = BufferSet(envs, BufferConfig()) + assert buffer.env_buffers["env_a"].num_normal == 5 + assert buffer.env_buffers["env_b"].num_normal == 5 samples = buffer.sample_examples(100) - env_a_count = sum(1 for p in samples if p["task"] == "env_a") + env_a_count = sum(1 for p in samples if p["env_name"] == "env_a") assert 60 <= env_a_count <= 95 def test_buffer_env_ratios_validation(): - """BufferConfig validates that all env_ratios are positive.""" + """Validates that env ratios must be positive and all-or-nothing.""" from pydantic import ValidationError - with pytest.raises(ValidationError, match="All env_ratios must be positive"): - BufferConfig(env_ratios=[0.5, -0.3, 0.2]) + from prime_rl.configs.orchestrator import TrainEnvsConfig + + with pytest.raises(ValidationError): + EnvConfig(id="env_a", ratio=-0.3) + + with pytest.raises(ValidationError, match="mix of set and unset"): + TrainEnvsConfig(env=[EnvConfig(id="a", ratio=0.5), EnvConfig(id="b")]) def test_buffer_no_cross_env_pool_assignment(mock_openai_client, tmp_path): - """Pool assignments don't transfer if example_id exists but task/env changed.""" - # Original: create an env_group with only env_a + """Pool assignments don't transfer if example_id exists but env changed.""" original_dataset = Dataset.from_dict({"question": ["q0"], "answer": ["a0"]}) original_env = vf.SingleTurnEnv( client=mock_openai_client, @@ -178,17 +208,15 @@ def test_buffer_no_cross_env_pool_assignment(mock_openai_client, tmp_path): dataset=original_dataset, rubric=vf.Rubric(), ) - original_env_group = vf.EnvGroup(envs=[original_env], env_names=["env_a"]) - original_env_dataset = original_env_group.get_dataset() - - buffer = Buffer(original_env_dataset, original_env_group.env_names, BufferConfig(easy_threshold=1.0)) - # Manually move the example to easy pool - example_id = list(buffer.example_buffer["env_a"].keys())[0] - example = buffer.example_buffer["env_a"].pop(example_id) - buffer.easy_examples.append(example) + original_env_set = make_envs({"env_a": make_env("env_a", original_env)}) + + buffer = BufferSet(original_env_set, BufferConfig(easy_threshold=1.0)) + eb = buffer.env_buffers["env_a"] + example_id = list(eb.examples.keys())[0] + example = eb.examples.pop(example_id) + eb.easy_examples.append(example) buffer.save(tmp_path / "buffer") - # Resume: create a new env_group with different dataset but similar structure new_dataset = Dataset.from_dict({"question": ["different_q"], "answer": ["different_a"]}) new_env = vf.SingleTurnEnv( client=mock_openai_client, @@ -196,13 +224,10 @@ def test_buffer_no_cross_env_pool_assignment(mock_openai_client, tmp_path): dataset=new_dataset, rubric=vf.Rubric(), ) - new_env_group = vf.EnvGroup(envs=[new_env], env_names=["env_b"]) - new_env_dataset = new_env_group.get_dataset() + new_env_set = make_envs({"env_b": make_env("env_b", new_env)}) - new_buffer = Buffer(new_env_dataset, new_env_group.env_names, BufferConfig()) + new_buffer = BufferSet(new_env_set, BufferConfig()) new_buffer.load(tmp_path / "buffer") - # Should NOT be in easy pool (different content, different hash) - assert len(new_buffer.easy_examples) == 0 - # Should still be in normal pool for env_b - assert len(new_buffer.example_buffer["env_b"]) == 1 + assert len(new_buffer.env_buffers["env_b"].easy_examples) == 0 + assert new_buffer.env_buffers["env_b"].num_normal == 1