Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/platform-monitoring.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Platform Monitoring

Use `orchestrator.prime_monitor` to register a run on the Prime Intellect platform and stream training metrics, samples, and distributions.
Use `orchestrator.prime` to register a run on the Prime Intellect platform and stream training metrics, samples, and distributions.

> **Internal-only for now:** external run registration is currently only enabled for internal / allowlisted teams.

Expand All @@ -23,14 +23,14 @@ export PRIME_API_KEY=pit_...
## Minimal config

```toml
[orchestrator.prime_monitor]
[orchestrator.prime]
run_name = "my-experiment"
```

You can also override the run name from the CLI:

```bash
uv run rl @ config.toml --orchestrator.prime_monitor.run_name "my-experiment"
uv run rl @ config.toml --orchestrator.prime.run-name "my-experiment"
```

## Troubleshooting
Expand Down
22 changes: 18 additions & 4 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
FileSystemTransportConfig,
HeartbeatConfig,
LogConfig,
PrimeMonitorConfig,
PrimeConfig,
TransportConfig,
WandbWithExtrasConfig,
)
Expand Down Expand Up @@ -764,7 +764,12 @@ class OrchestratorConfig(BaseConfig):
wandb: WandbWithExtrasConfig | None = None

# The Prime monitor configuration (the legacy `prime_monitor` key is still accepted as an alias)
prime_monitor: PrimeMonitorConfig | None = None
prime: Annotated[
PrimeConfig | None,
Field(
validation_alias=AliasChoices("prime", "prime_monitor"),
),
] = None

# The checkpoint configuration
ckpt: CheckpointConfig | None = None
Expand Down Expand Up @@ -908,6 +913,15 @@ class OrchestratorConfig(BaseConfig):
),
] = True

@model_validator(mode="before")
@classmethod
def warn_deprecated_prime_monitor(cls, data):
    """Warn when the raw config still uses the legacy ``prime_monitor`` key.

    Runs before field validation, so ``data`` is the unvalidated input.
    The input is never modified; it is returned as-is in every case.

    Args:
        data: Raw input handed to the model. Only ``dict`` inputs are
            inspected; anything else passes straight through.

    Returns:
        The same ``data`` object that was passed in.
    """
    uses_legacy_key = isinstance(data, dict) and "prime_monitor" in data
    if not uses_legacy_key:
        return data

    # Imported at call time rather than at module level
    # (presumably to avoid an import cycle with the logger module; TODO confirm).
    from prime_rl.utils.logger import get_logger

    logger = get_logger()
    logger.warning("Config: [prime_monitor] will be deprecated in a future version, use [prime] instead")
    return data

@model_validator(mode="after")
def validate_unique_filter_types(self):
types = [f.type for f in self.filters]
Expand Down Expand Up @@ -1013,8 +1027,8 @@ def auto_setup_bench(self):
self.eval = None
if self.wandb:
self.wandb.log_extras = None
if self.prime_monitor:
self.prime_monitor.log_extras = None
if self.prime:
self.prime.log_extras = None

return self

Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,9 @@ def auto_setup_wandb(self):

validate_shared_wandb_config(self.trainer, self.orchestrator)

if self.orchestrator.prime_monitor is not None and self.orchestrator.prime_monitor.run_name is None:
if self.orchestrator.prime is not None and self.orchestrator.prime.run_name is None:
if self.wandb and self.wandb.name:
self.orchestrator.prime_monitor.run_name = self.wandb.name
self.orchestrator.prime.run_name = self.wandb.name

return self

Expand Down
2 changes: 1 addition & 1 deletion src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ class WandbWithExtrasConfig(WandbConfig):
] = LogExtrasConfig()


class PrimeMonitorConfig(BaseConfig):
class PrimeConfig(BaseConfig):
"""Configures logging to Prime Intellect API."""

base_url: Annotated[
Expand Down
26 changes: 12 additions & 14 deletions src/prime_rl/orchestrator/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from prime_rl.configs.orchestrator import EvalSamplingConfig
from prime_rl.orchestrator.vf_utils import evaluate, get_completion_len
from prime_rl.utils.logger import get_logger
from prime_rl.utils.monitor import get_monitor
from prime_rl.utils.utils import capitalize


Expand Down Expand Up @@ -99,7 +98,12 @@ async def evaluate_env(
ckpt_step: int,
step: int,
get_client: Callable[[], Awaitable[vf.ClientConfig]],
):
) -> tuple[dict[str, Any], list[vf.RolloutOutput]]:
"""Run evaluation and return (metrics_dict, rollout_outputs).

Returns metrics prefixed with ``eval/{env_name}/`` and the raw outputs
so the caller can log them to any monitor backend.
"""
logger = get_logger()
logger.info(f"Evaluating {env_name} ({num_examples=}, {rollouts_per_example=})")
eval_start_time = time.perf_counter()
Expand All @@ -118,12 +122,8 @@ async def evaluate_env(

if not outputs:
logger.warning(f"All rollouts failed for {env_name} ({failed_rollouts} failed), skipping metrics")
monitor = get_monitor()
monitor.log(
{f"eval/{env_name}/failed_rollouts": failed_rollouts, "progress/ckpt_step": ckpt_step, "step": step},
step=step,
)
return
metrics = {f"eval/{env_name}/failed_rollouts": failed_rollouts, "progress/ckpt_step": ckpt_step, "step": step}
return metrics, []

rows = []
for output in outputs:
Expand Down Expand Up @@ -164,7 +164,7 @@ async def evaluate_env(
)
logger.success(message)

# Log statistics to monitor
# Build metrics dict
eval_metrics = {
f"avg@{rollouts_per_example}": float(results_df.reward.mean()),
"no_response/mean": float(results_df.no_response.mean()),
Expand All @@ -179,8 +179,6 @@ async def evaluate_env(
if could_be_binary:
assert pass_at_k is not None
eval_metrics.update(pd.Series(pass_at_k.mean()).to_dict())
eval_metrics = {**{f"eval/{env_name}/{k}": v for k, v in eval_metrics.items()}}
eval_metrics.update({"progress/ckpt_step": ckpt_step, "step": step})
monitor = get_monitor()
monitor.log(eval_metrics, step=step)
monitor.log_eval_samples(outputs, env_name=env_name, step=step)
metrics = {f"eval/{env_name}/{k}": v for k, v in eval_metrics.items()}
metrics.update({"progress/ckpt_step": ckpt_step, "step": step})
return metrics, outputs
65 changes: 41 additions & 24 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from prime_rl.utils.config import cli
from prime_rl.utils.heartbeat import Heartbeat
from prime_rl.utils.logger import setup_logger
from prime_rl.utils.monitor import setup_monitor
from prime_rl.utils.prime_monitor import PrimeMonitor
from prime_rl.utils.process import set_proc_title
from prime_rl.utils.temp_scheduling import compute_temperature
from prime_rl.utils.utils import (
Expand All @@ -76,6 +76,7 @@
strip_env_version,
to_col_format,
)
from prime_rl.utils.wandb_monitor import WandbMonitor


@clean_exit
Expand Down Expand Up @@ -149,11 +150,15 @@ async def orchestrate(config: OrchestratorConfig):
config.model.name, trust_remote_code=config.model.trust_remote_code, use_fast=True
)

# Setup monitor
logger.info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})")
monitor = setup_monitor(
wandb_config=config.wandb,
prime_config=config.prime_monitor,
# Setup monitors
wandb_monitor = WandbMonitor(
config=config.wandb,
output_dir=config.output_dir,
tokenizer=tokenizer,
run_config=config,
)
prime_monitor = PrimeMonitor(
config=config.prime,
output_dir=config.output_dir,
tokenizer=tokenizer,
run_config=config,
Expand Down Expand Up @@ -480,7 +485,7 @@ def _cleanup_env_processes():
logger.info("Cancelling in-flight training rollouts before starting evals to avoid congestion.")
await scheduler.cancel_inflight_rollouts()

results = await asyncio.gather(
eval_results = await asyncio.gather(
*[
evaluate_env(
env=eval_env,
Expand All @@ -498,6 +503,12 @@ def _cleanup_env_processes():
]
)

for (eval_metrics, eval_outputs), eval_env_name in zip(eval_results, eval_env_names):
wandb_monitor.log(eval_metrics, step=progress.step)
prime_monitor.log(eval_metrics, step=progress.step)
if eval_outputs:
wandb_monitor.log_eval_samples(eval_outputs, env_name=eval_env_name, step=progress.step)

# Resume weight updates
scheduler.checkpoint_ready.set()

Expand Down Expand Up @@ -819,18 +830,17 @@ def compute_solve_rates(df):
to_log[f"val/reward/{env}/max"] = env_by_example.reward.mean().max()
to_log[f"val/reward/{env}/min"] = env_by_example.reward.mean().min()

# Log metrics to monitor(s)
monitor.log(to_log, step=progress.step)

# Log samples to monitor(s) if enabled.
monitor.log_samples(train_rollouts, step=progress.step)

# Log distributions (rewards, advantages) if enabled
monitor.log_distributions(
distributions={
"rewards": rewards,
"advantages": advantages,
},
# Log metrics to monitors
wandb_monitor.log(to_log, step=progress.step)
wandb_monitor.log_samples(train_rollouts, step=progress.step)
wandb_monitor.log_distributions(
distributions={"rewards": rewards, "advantages": advantages},
step=progress.step,
)
prime_monitor.log(to_log, step=progress.step)
prime_monitor.log_samples(train_rollouts, step=progress.step)
prime_monitor.log_distributions(
distributions={"rewards": rewards, "advantages": advantages},
step=progress.step,
)

Expand Down Expand Up @@ -859,7 +869,7 @@ def compute_solve_rates(df):

if config.eval:
logger.info("Running final evals")
results = await asyncio.gather(
eval_results = await asyncio.gather(
*[
evaluate_env(
env=eval_env,
Expand All @@ -877,9 +887,16 @@ def compute_solve_rates(df):
]
)

# Log final (immutable) samples and distributions to monitor(s)
monitor.log_final_samples()
monitor.save_final_summary()
for (eval_metrics, eval_outputs), eval_env_name in zip(eval_results, eval_env_names):
wandb_monitor.log(eval_metrics, step=progress.step)
prime_monitor.log(eval_metrics, step=progress.step)
if eval_outputs:
wandb_monitor.log_eval_samples(eval_outputs, env_name=eval_env_name, step=progress.step)

# Log final (immutable) samples and distributions to monitors
wandb_monitor.log_final_samples()
wandb_monitor.save_final_summary()
prime_monitor.save_final_summary()

# Write final checkpoint
if ckpt_manager is not None:
Expand Down Expand Up @@ -912,7 +929,7 @@ def compute_solve_rates(df):

# Optionally, print benchmark table
if config.bench:
print_benchmark(to_col_format(monitor.history))
print_benchmark(to_col_format(wandb_monitor.history))


def main():
Expand Down
18 changes: 9 additions & 9 deletions src/prime_rl/trainer/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from prime_rl.trainer.models.layers.lora import set_lora_num_tokens
from prime_rl.utils.heartbeat import Heartbeat
from prime_rl.utils.metrics_server import HealthServer, MetricsServer, RunStats
from prime_rl.utils.monitor import setup_monitor
from prime_rl.utils.wandb_monitor import WandbMonitor
from prime_rl.utils.config import cli
from prime_rl.utils.process import set_proc_title
from prime_rl.utils.utils import clean_exit, resolve_latest_ckpt_step, to_col_format
Expand All @@ -80,7 +80,7 @@ def train(config: TrainerConfig):

# Setup the monitor
logger.info(f"Initializing monitor ({config.wandb})")
monitor = setup_monitor(config.wandb, output_dir=config.output_dir, run_config=config)
wandb_monitor = WandbMonitor(config=config.wandb, output_dir=config.output_dir, run_config=config)

# Setup heartbeat (only on rank 0)
heart = None
Expand Down Expand Up @@ -523,7 +523,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
"perf/peak_memory": peak_memory,
"step": progress.step,
}
monitor.log(perf_metrics, step=progress.step)
wandb_monitor.log(perf_metrics, step=progress.step)

# Log optimizer metrics
optim_metrics = {
Expand All @@ -532,7 +532,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
"optim/zero_grad_ratio": zero_grad_ratio,
"step": progress.step,
}
monitor.log(optim_metrics, step=progress.step)
wandb_monitor.log(optim_metrics, step=progress.step)

# Compute derived metrics
entropy_mean = tensor_stats.get("entropy/mean", 0.0)
Expand All @@ -542,7 +542,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:

# Log tensor stats
tensor_stats["step"] = progress.step
monitor.log(tensor_stats, step=progress.step)
wandb_monitor.log(tensor_stats, step=progress.step)

# Log time metrics
time_metrics = {
Expand All @@ -554,12 +554,12 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
"time/forward_backward": forward_backward_time,
"step": progress.step,
}
monitor.log(time_metrics, step=progress.step)
wandb_monitor.log(time_metrics, step=progress.step)

# Log disk metrics
disk_metrics = get_ckpt_disk_metrics(config.output_dir)
disk_metrics["step"] = progress.step
monitor.log(disk_metrics, step=progress.step)
wandb_monitor.log(disk_metrics, step=progress.step)

# Update Prometheus metrics if configured
if metrics_server is not None:
Expand Down Expand Up @@ -628,7 +628,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
weight_ckpt_manager.save(progress.step, model, tokenizer)
weight_ckpt_manager.maybe_clean()

logger.info(f"Peak memory: {max(to_col_format(monitor.history)['perf/peak_memory']):.1f} GiB")
logger.info(f"Peak memory: {max(to_col_format(wandb_monitor.history)['perf/peak_memory']):.1f} GiB")
logger.success("RL trainer finished!")

# Stop metrics/health server if configured
Expand All @@ -639,7 +639,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:

# Optionally, print benchmark table and export JSON
if config.bench is not None and world.is_master:
history = to_col_format(monitor.history)
history = to_col_format(wandb_monitor.history)
print_benchmark(history)
if config.bench.output_json:
export_benchmark_json(history, config.bench.output_json)
Expand Down
Loading
Loading