Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/prime_rl/utils/monitor/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
self.logger = get_logger()
self.history: list[dict[str, Any]] = []
self.output_dir = output_dir
self._jsonl_file = None

rank = int(os.environ.get("RANK", os.environ.get("DP_RANK", "0")))
self.enabled = self.config is not None
Expand All @@ -46,6 +47,7 @@ def __init__(
self._maybe_overwrite_wandb_command()

shared_mode = os.environ.get("WANDB_SHARED_MODE") == "1"
self._setup_jsonl(shared_mode)
if shared_mode:
run_id = os.environ.get("WANDB_SHARED_RUN_ID")
label = os.environ.get("WANDB_SHARED_LABEL")
Expand Down Expand Up @@ -108,6 +110,16 @@ def init_wandb(max_retries: int):
log_mode="INCREMENTAL",
)

def _setup_jsonl(self, shared_mode: bool) -> None:
    """Open a JSONL file for human-readable metric logging.

    The file is ``<output_dir>/metrics/<label>.jsonl`` and is opened in
    append mode, so content already in the file is kept. Nothing happens
    when ``self.output_dir`` is ``None``.

    Args:
        shared_mode: When true, the file name is taken from the
            ``WANDB_SHARED_LABEL`` environment variable. Otherwise, or when
            that variable is unset or empty, the name is ``metrics``.
    """
    if self.output_dir is None:
        return

    # Release a handle left over from an earlier call so it is not leaked.
    previous = getattr(self, "_jsonl_file", None)
    if previous is not None:
        previous.close()
        self._jsonl_file = None

    label = "metrics"
    if shared_mode:
        # `or` also covers an empty value, which would otherwise produce a
        # hidden file named ".jsonl".
        label = os.environ.get("WANDB_SHARED_LABEL") or "metrics"

    # NOTE(review): assumes output_dir supports the `/` operator
    # (e.g. pathlib.Path) — confirm against the constructor's callers.
    path = self.output_dir / "metrics" / f"{label}.jsonl"
    path.parent.mkdir(parents=True, exist_ok=True)
    # An explicit encoding keeps the file independent of the platform locale.
    self._jsonl_file = open(path, "a", encoding="utf-8")
    self.logger.info(f"Logging metrics to {path}")

def _maybe_overwrite_wandb_command(self) -> None:
"""Overwrites sys.argv with the start command if it is set in the environment variables."""
wandb_args = os.environ.get("WANDB_ARGS", None)
Expand All @@ -122,6 +134,9 @@ def log(self, metrics: dict[str, Any], step: int) -> None:
if not self.enabled:
return
wandb.log({**metrics, "step": step})
if self._jsonl_file is not None:
self._jsonl_file.write(json.dumps({**metrics, "step": step}, default=str) + "\n")
self._jsonl_file.flush()

def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None:
"""Logs rollouts to W&B table."""
Expand Down Expand Up @@ -228,6 +243,11 @@ def log_distributions(self, distributions: dict[str, list[float]], step: int) ->
"""Log distributions (no-op for W&B)."""
pass

def close(self) -> None:
    """Close the JSONL metrics file if one is open; otherwise do nothing."""
    handle = self._jsonl_file
    if handle is None:
        return
    handle.close()
    # Drop the reference so later calls see that no file is open.
    self._jsonl_file = None

def save_final_summary(self, filename: str = "final_summary.json") -> None:
"""Save final summary to W&B table."""
if not self.is_master or not self.enabled:
Expand Down
Loading