Add Trackio rollout trace logging#1697
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for Trackio logging, allowing rollout samples to be logged as traces. Changes include new configuration options in TrackioTraceConfig, the addition of the trackio dependency, and the implementation of _log_trackio_rollout_traces within the trainer to handle trace record construction and logging. Feedback focuses on two main issues: first, the logging call in fully_async_trainer.py is currently inside a loop, which could lead to excessive logging and performance degradation; it should be moved outside the loop using a merged input batch. Second, the logic for mapping responses to prompts in trainer.py is flawed when multiple samples are generated per prompt, and should be updated to use a calculation based on the number of samples per prompt.
| self._log_trackio_rollout_traces( | ||
| cur_generated_output_group.generator_input, | ||
| cur_generated_output_group.generator_output, | ||
| ) |
There was a problem hiding this comment.
Logging traces inside the loop will result in up to max_traces_per_step traces being logged per prompt group in the mini-batch, rather than per training step. For a typical mini-batch size (e.g., 256), this will log thousands of traces, likely exceeding the intended limit and impacting logging performance/storage.
It is better to move the logging call after the loop and pass a merged input batch to log a single set of traces for the entire training step.
generator_inputs = []
for cur_generated_output_group in cur_generation_group_mini_batch:
cur_staleness = self.global_step - cur_generated_output_group.global_step_when_scheduled
stalenesses.append(cur_staleness)
generator_outputs.append(cur_generated_output_group.generator_output)
generator_inputs.append(cur_generated_output_group.generator_input)
# NOTE(Charlie): for step-wise training each group can contain a variable number of entries
# (n_samples_per_prompt * variable turns_per_trajectory), so the uid fanout is per-group.
group_size = len(cur_generated_output_group.generator_output["response_ids"])
uids.extend([cur_generated_output_group.uid] * group_size)
# Check staleness violation.
if cur_staleness > self.max_staleness_steps:
# ... (rest of the loop logic) ...
generator_output = concatenate_generator_outputs(
generator_outputs, step_wise=self.cfg.generator.step_wise_trajectories
)
# Construct a merged input batch for logging traces
merged_input = {
"prompts": [g.generator_input["prompts"][0] for g in cur_generation_group_mini_batch],
"trajectory_ids": [g.generator_input["trajectory_ids"][0] for g in cur_generation_group_mini_batch if g.generator_input.get("trajectory_ids")],
"batch_metadata": cur_generation_group_mini_batch[0].generator_input.get("batch_metadata")
}
self._log_trackio_rollout_traces(merged_input, generator_output)There was a problem hiding this comment.
Addressed in the latest push. Fully async now collects the generated groups first, concatenates the generator outputs, merges the corresponding generator inputs, and logs Trackio traces once for the training mini-batch so max_traces_per_step is enforced per step rather than per prompt group.
|
|
||
| num_traces = min(max_traces, len(generator_output["response_ids"])) | ||
| for sample_index in range(num_traces): | ||
| input_idx = sample_index if sample_index < len(prompts) else len(prompts) - 1 |
There was a problem hiding this comment.
The fallback logic for calculating input_idx is incorrect when n_samples_per_prompt > 1. If trajectory_ids are missing or fail to match, this logic will incorrectly map multiple responses to the same prompt or exceed the bounds of the prompts list.
Since the generator output layout is typically blocked (all samples for prompt 0, then all for prompt 1, etc.), the index should be calculated based on the number of samples per prompt.
| input_idx = sample_index if sample_index < len(prompts) else len(prompts) - 1 | |
| samples_per_prompt = len(generator_output["response_ids"]) // len(prompts) | |
| input_idx = sample_index // samples_per_prompt |
There was a problem hiding this comment.
Also addressed while updating the high-priority item. The fallback prompt mapping now computes a blocked sample-to-prompt index from the number of responses per prompt, with a bound check, and there is a focused test for the missing-trajectory-id case.
68c3dd0 to
83f8176
Compare
Hi folks! This PR adds trace logging via Trackio, the free, local-first experiment tracking library from Hugging Face 🤗
This PR follows SkyRL's existing metric logging and rollout generation patterns. Specifically, I did this:
trackioas a supported training logger backendtrainer.trackio_trace.max_traces_per_stepandtrainer.trackio_trace.trace_keytrackio.Tracerecords with step, split, reward, stop reason, and trajectory metadataHere's what it looks like:
AI assistance was used to prepare this PR.