diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index cf9dcfa02e..0007fc47ab 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -281,6 +281,13 @@ def _update_run_progress(self, run_idx: int, num_samples: int, num_tokens: int) self.multi_run_manager.progress[run_idx].total_tokens += num_tokens self.multi_run_manager.progress[run_idx].total_samples += num_samples + def get_buffer_stats(self) -> tuple[dict[str, int], int]: + """Return per-run buffer lengths and round-robin position for metrics.""" + buffer_lens = { + self.multi_run_manager.idx_2_id[idx]: len(self.buffers[idx]) for idx in self.multi_run_manager.used_idxs + } + return buffer_lens, self._round_robin_position + def pack(self): """Pack samples from buffers using round-robin fair scheduling.""" self._get_batch() diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 626ed151c7..9fb37fc327 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -17,6 +17,7 @@ from prime_rl.trainer.scheduler import setup_scheduler, setup_multi_scheduler from prime_rl.configs.trainer import TrainerConfig from prime_rl.trainer.rl.data import DataLoader, FakeDataLoader +from prime_rl.trainer.rl.packer import MultiPacker from prime_rl.utils.cp import ( gather_for_cp, gather_for_cp_wo_grad, @@ -596,6 +597,9 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: runs_max=multi_run_manager.max_runs, run_stats=run_stats, ) + if isinstance(dataloader, DataLoader) and isinstance(dataloader.packer, MultiPacker): + buffer_lengths, rr_pos = dataloader.packer.get_buffer_stats() + metrics_server.update_packer(buffer_lengths, rr_pos) progress.step += 1 is_first_step = False diff --git a/src/prime_rl/utils/metrics_server.py b/src/prime_rl/utils/metrics_server.py index f19736edf8..04b0310fc6 100644 --- a/src/prime_rl/utils/metrics_server.py +++ b/src/prime_rl/utils/metrics_server.py @@ -143,8 +143,16 @@ def __init__(self, config: "MetricsServerConfig"): ["run"], registry=self._registry, ) + # Packer metrics + self._packer_buffer_length = Gauge( + "trainer_packer_buffer_length", "Number of samples in packer buffer", ["run"], registry=self._registry + ) + self._packer_round_robin_position = Gauge( + "trainer_packer_round_robin_position", "Current round-robin index in packer", registry=self._registry + ) # Track known run labels for cleanup self._known_runs: set[str] = set() + self._known_packer_runs: set[str] = set() def _make_handler(self) -> type[BaseHTTPRequestHandler]: """Create handler with /metrics and /health endpoints.""" @@ -256,3 +264,21 @@ def update_runs( self._run_ready.labels(run=run.run_id).set(1 if run.ready else 0) self._known_runs = current_runs + + def update_packer(self, buffer_lengths: dict[str, int], round_robin_position: int) -> None: + """Update packer buffer metrics. + + Args: + buffer_lengths: Mapping of run_id to number of buffered samples + round_robin_position: Current round-robin index + """ + self._packer_round_robin_position.set(round_robin_position) + + current_runs = set(buffer_lengths.keys()) + for run_id in self._known_packer_runs - current_runs: + self._packer_buffer_length.remove(run_id) + + for run_id, length in buffer_lengths.items(): + self._packer_buffer_length.labels(run=run_id).set(length) + + self._known_packer_runs = current_runs