Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/prime_rl/trainer/rl/packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ def _update_run_progress(self, run_idx: int, num_samples: int, num_tokens: int)
self.multi_run_manager.progress[run_idx].total_tokens += num_tokens
self.multi_run_manager.progress[run_idx].total_samples += num_samples

def get_buffer_stats(self) -> tuple[dict[str, int], int]:
    """Snapshot the packer's buffer occupancy for metrics reporting.

    Returns:
        A ``(buffer_lens, position)`` pair: ``buffer_lens`` maps each run id
        (looked up through ``multi_run_manager.idx_2_id``) to the number of
        items currently held in that run's buffer, covering only the indices
        listed in ``multi_run_manager.used_idxs``; ``position`` is the current
        value of ``self._round_robin_position``.
    """
    manager = self.multi_run_manager
    per_run_lengths: dict[str, int] = {}
    for run_idx in manager.used_idxs:
        # Resolve the run id before measuring the buffer, one entry per used index.
        run_id = manager.idx_2_id[run_idx]
        per_run_lengths[run_id] = len(self.buffers[run_idx])
    return per_run_lengths, self._round_robin_position

def pack(self):
"""Pack samples from buffers using round-robin fair scheduling."""
self._get_batch()
Expand Down
4 changes: 4 additions & 0 deletions src/prime_rl/trainer/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from prime_rl.trainer.scheduler import setup_scheduler, setup_multi_scheduler
from prime_rl.configs.trainer import TrainerConfig
from prime_rl.trainer.rl.data import DataLoader, FakeDataLoader
from prime_rl.trainer.rl.packer import MultiPacker
from prime_rl.utils.cp import (
gather_for_cp,
gather_for_cp_wo_grad,
Expand Down Expand Up @@ -596,6 +597,9 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
runs_max=multi_run_manager.max_runs,
run_stats=run_stats,
)
if isinstance(dataloader, DataLoader) and isinstance(dataloader.packer, MultiPacker):
buffer_lengths, rr_pos = dataloader.packer.get_buffer_stats()
metrics_server.update_packer(buffer_lengths, rr_pos)

progress.step += 1
is_first_step = False
Expand Down
26 changes: 26 additions & 0 deletions src/prime_rl/utils/metrics_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,16 @@ def __init__(self, config: "MetricsServerConfig"):
["run"],
registry=self._registry,
)
# Packer metrics
self._packer_buffer_length = Gauge(
"trainer_packer_buffer_length", "Number of samples in packer buffer", ["run"], registry=self._registry
)
self._packer_round_robin_position = Gauge(
"trainer_packer_round_robin_position", "Current round-robin index in packer", registry=self._registry
)
# Track known run labels for cleanup
self._known_runs: set[str] = set()
self._known_packer_runs: set[str] = set()

def _make_handler(self) -> type[BaseHTTPRequestHandler]:
"""Create handler with /metrics and /health endpoints."""
Expand Down Expand Up @@ -256,3 +264,21 @@ def update_runs(
self._run_ready.labels(run=run.run_id).set(1 if run.ready else 0)

self._known_runs = current_runs

def update_packer(self, buffer_lengths: dict[str, int], round_robin_position: int) -> None:
    """Publish packer buffer gauges and drop series for runs no longer reported.

    Args:
        buffer_lengths: Mapping of run_id to number of buffered samples.
        round_robin_position: Current round-robin index.
    """
    self._packer_round_robin_position.set(round_robin_position)

    live_runs = set(buffer_lengths)
    length_gauge = self._packer_buffer_length

    # Runs seen on a previous call but absent now would otherwise keep
    # exporting their last value, so remove their labelled series.
    stale_runs = self._known_packer_runs.difference(live_runs)
    for stale_id in stale_runs:
        length_gauge.remove(stale_id)

    for run_id in buffer_lengths:
        length_gauge.labels(run=run_id).set(buffer_lengths[run_id])

    # Remember what was exported so the next call can detect removals.
    self._known_packer_runs = live_runs
Loading