Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/Intellect-3.1/rl.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,34 @@ oversampling_factor = 2
[[orchestrator.env]]
id = "mini-swe-agent-plus"
name = "swe"
ratio = 0.3
args = { max_turns = 200, cpu_cores = 2, memory_gb = 4, disk_size_gb = 4, labels = ["mini-swe-agent-plus"], total_timeout_minutes = 720, sandbox_client_max_workers = 256, max_command_timeouts = 3, sandbox_command_timeout = 30}

[[orchestrator.env]]
id = "deepdive"
name = "deepdive"
ratio = 0.2
args = { finish_with_tool = true, open_max_workers = 128, cache_dir = "/tmp/i3_deepdive_cache_train" }

[[orchestrator.env]]
id = "math-env"
name = "math"
ratio = 0.3
args = { min_avg_reward = 0.0, max_avg_reward = 0.874}

[[orchestrator.env]]
id = "logic-env"
args = { min_avg_reward = 0.0, max_avg_reward = 0.874 }
name = "logic"
ratio = 0.2
args = { min_avg_reward = 0.0, max_avg_reward = 0.874 }

[[orchestrator.env]]
id = "code-env"
name = "code"
ratio = 0.2
args = { pool_size = 512 }

[orchestrator.buffer]
env_ratios = [0.3, 0.2, 0.3, 0.2, 0.2]
easy_threshold = 1.0
online_difficulty_filtering = true
seed = 42
Expand Down
198 changes: 98 additions & 100 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,41 +268,58 @@ class EvalSaveHFConfig(BaseConfig):


class EnvConfig(BaseConfig):
"""Configures an environment for training."""
"""Base environment configuration."""

id: Annotated[str, Field(description="ID of the environment to use.")] = "reverse-text"
args: Annotated[dict, Field(description="Arguments to pass to the environment.")] = {}
name: Annotated[str | None, Field(description="Name of the environment to use.")] = None
address: Annotated[
str | None,
Field(
description="Address of the environment server. If None, will spawn an environment server in a subprocess automatically.If given, will try to connect an environment client to the environment server at this address."
),
] = None
extra_env_kwargs: Annotated[
dict[str, Any],
Field(
description=(
"Extra kwargs passed to an env (e.g. seq_len, score_rollouts). This field is auto-populated with the seq_len, and score_rollouts for training envs on the orchestrator. It is generally NOT recommended for this field to be overriden by the user. It's main use case is to match the extra_env_kwargs when running an env in an isolated environment server."
"Extra kwargs passed to an env (e.g. seq_len, score_rollouts). This field is auto-populated "
"by the orchestrator. It is generally NOT recommended for this field to be overridden by the user."
),
),
] = {}

address: Annotated[
str | None,
Field(
description="Address of the environment server. If None, will spawn an environment server in a subprocess automatically. If given, will connect an environment client to the server at this address."
),
] = None

num_workers: Annotated[
int | Literal["auto"],
Field(
description=(
"Number of env server worker processes. "
"Set to 'auto' to scale based on the env's concurrency (1 worker per 256 concurrent rollouts). "
"When setting manually, we recommend sizing so that each worker handles at most 256 concurrent rollouts. "
"Only used when the orchestrator spawns the env server (i.e. address is None)."
),
description="Number of env server worker processes. "
"Set to 'auto' to scale based on the env's concurrency (1 worker per 256 concurrent rollouts). "
"When setting manually, we recommend sizing so that each worker handles at most 256 concurrent rollouts. "
"Only used when the orchestrator spawns the env server (i.e. address is None)."
),
] = "auto"

score_rollouts: Annotated[
bool,
Field(
description="Whether to score rollouts using the environment rubric. If False, rewards are always set to 0.",
),
] = True

ratio: Annotated[
float | None,
Field(
gt=0,
description="Sampling ratio for this environment in the buffer. If None for all envs, samples uniformly across all available problems. If set, ratios across all envs must sum to 1.",
),
] = None

max_retries: Annotated[
int,
Field(
ge=0,
description="Maximum number of times the environment will retry a failed rollout.",
description="Maximum number of internal retries the environment will attempt before declaring a rollout failed.",
),
] = 0

Expand All @@ -319,54 +336,71 @@ def validate_env_name(self):
return self


# Alias: training environments use the base EnvConfig unchanged; kept as a
# distinct name so train-specific fields can be split out later without
# touching call sites.
TrainEnvConfig = EnvConfig


class EvalEnvConfig(EnvConfig):
"""Configures an environment for evaluation."""
"""Configures an evaluation environment."""

num_examples: Annotated[
int | None,
Field(
description="Number of examples to evaluate per environment. If not set, will use 'num_examples' from main config."
),
] = None
rollouts_per_example: Annotated[
int | None,
Field(
description="Number of samples to generate per example for each environment. If not set, will use 'rollouts_per_example' from main config."
),
] = None
int,
Field(description="Number of examples to evaluate."),
] = -1

skip_first: Annotated[
rollouts_per_example: Annotated[
int,
Field(
description="Number of examples to skip from the beginning of the dataset.",
),
] = 0
Field(ge=1, description="Number of rollouts to generate per example."),
] = 1


class ValConfig(BaseConfig):
"""Configures the validation of the model."""
class TrainEnvsConfig(BaseConfig):
"""Configures all training environments."""

num_examples: Annotated[
int, Field(ge=1, description="Number of examples to use for validation. If -1, will use all examples.")
] = 16
rollouts_per_example: Annotated[
int, Field(ge=1, description="Number of samples to generate per example for validation.")
] = 1
interval: Annotated[int, Field(description="Interval at which to validate the model.")] = 10
# Training environments to sample rollouts from; defaults to a single
# default-configured env (EnvConfig's default id is "reverse-text").
env: list[TrainEnvConfig] = [TrainEnvConfig()]

@model_validator(mode="after")
def validate_unique_env_names(self):
    """Reject configs where two training envs resolve to the same name."""
    # resolved_name is defined on EnvConfig; duplicates would make envs
    # indistinguishable wherever rollouts are keyed by env name.
    env_names = [env.resolved_name for env in self.env]
    duplicates = [n for n in env_names if env_names.count(n) > 1]
    if duplicates:
        raise ValueError(
            f"Duplicate training environment names: {set(duplicates)}. Each env must have a unique name."
        )
    return self

@model_validator(mode="after")
def validate_env_ratios(self):
    """Enforce all-or-none semantics for per-env sampling ratios.

    NOTE(review): the `ratio` field description says ratios across all envs
    must sum to 1, but this validator only checks that ratios are either set
    on every env or on none. Confirm the sum-to-1 constraint is enforced
    elsewhere (the example rl.toml ratios sum to 1.2).
    """
    ratios = [env.ratio for env in self.env]
    # No env opted into ratio-based sampling: uniform sampling, nothing to check.
    if all(r is None for r in ratios):
        return self
    # A partial set of ratios is ambiguous, so it is rejected outright.
    if any(r is None for r in ratios):
        raise ValueError("Either all envs must have a ratio or none of them. Got a mix of set and unset ratios.")
    return self

class EvalConfig(BaseConfig):
"""Configures evaluation using verifiers environments."""

class EvalEnvsConfig(BaseConfig):
    """Configures all evaluation environments."""

    # Evaluation environments; defaults to a single default-configured env.
    env: list[EvalEnvConfig] = [EvalEnvConfig()]

    @model_validator(mode="after")
    def validate_unique_env_names(self):
        """Reject configs where two eval envs resolve to the same name."""
        # Duplicate names would make per-env eval results collide.
        env_names = [env.resolved_name for env in self.env]
        duplicates = [n for n in env_names if env_names.count(n) > 1]
        if duplicates:
            raise ValueError(
                f"Duplicate evaluation environment names: {set(duplicates)}. Each env must have a unique name."
            )
        return self


class EvalConfig(EvalEnvsConfig):
"""Configures evaluation using verifiers environments."""

sampling: EvalSamplingConfig = Field(
default_factory=EvalSamplingConfig,
description="Shared sampling configuration for evals; can differ from training sampling.",
)
num_examples: Annotated[int, Field(description="Number of examples to evaluate per environment.")] = -1
rollouts_per_example: Annotated[
int, Field(ge=1, description="Number of samples to generate per example for each environment.")
] = 1

interval: Annotated[
int,
Expand Down Expand Up @@ -401,14 +435,6 @@ class EvalConfig(BaseConfig):
),
] = False

@model_validator(mode="after")
def validate_unique_env_names(self):
env_names = [env.resolved_name for env in self.env]
duplicates = [n for n in env_names if env_names.count(n) > 1]
if duplicates:
raise ValueError(f"Duplicate eval environment names: {set(duplicates)}. Each env must have a unique name.")
return self


class CheckpointConfig(BaseConfig):
"""Configures checkpointing the orchestrator."""
Expand Down Expand Up @@ -472,16 +498,6 @@ class BufferConfig(BaseConfig):
),
] = None

env_ratios: Annotated[
list[float] | None,
Field(
description=(
"Ratios for sampling from each environment. "
"If None, samples uniformly across all available problems (not environments)."
),
),
] = None

easy_threshold: Annotated[
float | None,
Field(
Expand Down Expand Up @@ -535,12 +551,6 @@ def validate_thresholds(self):
assert self.easy_threshold > self.hard_threshold, "easy_threshold must be greater than hard_threshold."
return self

@model_validator(mode="after")
def validate_env_ratios(self):
if self.env_ratios is not None:
assert all(ratio > 0 for ratio in self.env_ratios), "All env_ratios must be positive."
return self


class VerificationConfig(BaseConfig):
"""Configures rollout verification and rubric scoring."""
Expand Down Expand Up @@ -703,7 +713,7 @@ class TeacherRolloutModelConfig(BaseConfig):
] = ModelConfig()


class OrchestratorConfig(BaseConfig):
class OrchestratorConfig(TrainEnvsConfig):
"""Configures the orchestrator for RL training."""

# The OAI client configuration
Expand Down Expand Up @@ -739,9 +749,6 @@ class OrchestratorConfig(BaseConfig):
# The sampling configuration
sampling: SamplingConfig = SamplingConfig()

# The environment configuration
env: list[EnvConfig] = [EnvConfig()]

# The evaluation configuration
eval: EvalConfig | None = None

Expand Down Expand Up @@ -769,9 +776,6 @@ class OrchestratorConfig(BaseConfig):
# The checkpoint configuration
ckpt: CheckpointConfig | None = None

# The validation configuration
val: ValConfig | None = None

weight_broadcast: WeightBroadcastConfig = FileSystemWeightBroadcastConfig()

rollout_transport: TransportConfig = FileSystemTransportConfig()
Expand Down Expand Up @@ -908,6 +912,16 @@ class OrchestratorConfig(BaseConfig):
),
] = True

@property
def train_envs(self) -> list[TrainEnvConfig]:
    """Training env configs; alias for the `env` field inherited from TrainEnvsConfig."""
    return self.env

@property
def eval_envs(self) -> list[EvalEnvConfig]:
    """Return the configured eval envs, or an empty list when evaluation is disabled."""
    return [] if self.eval is None else self.eval.env

@model_validator(mode="after")
def validate_unique_filter_types(self):
types = [f.type for f in self.filters]
Expand Down Expand Up @@ -960,20 +974,6 @@ def resolve_batching(self):
raise ValueError("max_inflight_rollouts must be at least the number of rollouts per example")
return self

@model_validator(mode="after")
def validate_unique_env_names(self):
env_names = [env.resolved_name for env in self.env]
duplicates = [n for n in env_names if env_names.count(n) > 1]
if duplicates:
raise ValueError(f"Duplicate environment names: {set(duplicates)}. Each env must have a unique name.")
return self

@model_validator(mode="after")
def validate_env_ratios(self):
if self.buffer.env_ratios is not None:
assert len(self.buffer.env_ratios) == len(self.env), "env_ratios length must match number of environments"
return self

@model_validator(mode="after")
def validate_verification_config(self):
if self.verification.enabled:
Expand Down Expand Up @@ -1019,15 +1019,13 @@ def auto_setup_bench(self):
return self

@model_validator(mode="after")
def resolve_extra_env_kwargs(self):
train_extra_env_kwargs = dict(
max_seq_len=self.seq_len,
score_rollouts=self.verification.enabled,
)
def resolve_env_config(self):
"""Populate extra_env_kwargs from top-level and per-env fields."""
for env in self.env:
# extra_env_kwargs is not meant to be used by the user, we shamelessly override here
env.extra_env_kwargs.update(train_extra_env_kwargs)

env.extra_env_kwargs.update(
max_seq_len=self.seq_len,
score_rollouts=env.score_rollouts,
)
return self

@model_validator(mode="after")
Expand Down
Loading
Loading