Add Trackio rollout trace logging #1697

Open: abidlabs wants to merge 1 commit into NovaSky-AI:main from abidlabs:add-trackio-rollout-traces

@abidlabs commented May 22, 2026

Hi folks! This PR adds trace logging via Trackio, the free, local-first experiment tracking library from Hugging Face 🤗

This PR follows SkyRL's existing metric logging and rollout generation patterns. Specifically, I did this:

  • added trackio as a supported training logger backend
  • added configurable rollout trace logging with trainer.trackio_trace.max_traces_per_step and trainer.trackio_trace.trace_key
  • logged rollout prompt/response conversations as trackio.Trace records with step, split, reward, stop reason, and trajectory metadata
  • covered both synchronous and fully async rollout generation paths
  • documented the Trackio logger and trace settings
  • added a mocked Trackio rollout trace test
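The trace settings and record fields listed above could be sketched roughly as follows. This is a minimal illustration, not the PR's implementation: the default values, the `build_trace_records` helper, and the plain-dict record layout are all assumptions, and the sketch assumes one response per prompt so that `prompts[i]` lines up with `responses[i]`. It deliberately avoids calling Trackio itself.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class TrackioTraceConfig:
    # Field names come from the PR description; the defaults are assumed.
    max_traces_per_step: int = 4
    trace_key: str = "rollout_traces"


def build_trace_records(
    prompts: List[List[Dict[str, str]]],
    responses: List[str],
    rewards: List[float],
    stop_reasons: List[str],
    step: int,
    split: str,
    cfg: TrackioTraceConfig,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: build at most max_traces_per_step plain-dict records."""
    num_traces = min(cfg.max_traces_per_step, len(responses))
    records = []
    for i in range(num_traces):
        # Append the response to a copy of the prompt conversation.
        messages = list(prompts[i]) + [{"role": "assistant", "content": responses[i]}]
        records.append(
            {
                "messages": messages,
                "metadata": {
                    "step": step,
                    "split": split,
                    "reward": rewards[i],
                    "stop_reason": stop_reasons[i],
                },
            }
        )
    return records
```

Capping with `min(...)` before the loop keeps the per-step cost bounded regardless of batch size.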

Here's what it looks like:

image

AI assistance was used to prepare this PR.

@abidlabs abidlabs marked this pull request as ready for review May 22, 2026 23:37
Contributor

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for Trackio logging, allowing rollout samples to be logged as traces. Changes include new configuration options in TrackioTraceConfig, the addition of the trackio dependency, and the implementation of _log_trackio_rollout_traces within the trainer to handle trace record construction and logging. Feedback focuses on two main issues: first, the logging call in fully_async_trainer.py is currently inside a loop, which could lead to excessive logging and performance degradation; it should be moved outside the loop using a merged input batch. Second, the logic for mapping responses to prompts in trainer.py is flawed when multiple samples are generated per prompt, and should be updated to use a calculation based on the number of samples per prompt.

Comment thread skyrl/train/fully_async_trainer.py Outdated
Comment on lines +626 to +629
    self._log_trackio_rollout_traces(
        cur_generated_output_group.generator_input,
        cur_generated_output_group.generator_output,
    )
Contributor

Priority: high

Logging traces inside the loop will result in up to max_traces_per_step traces being logged per prompt group in the mini-batch, rather than per training step. For a typical mini-batch size (e.g., 256), this will log thousands of traces, likely exceeding the intended limit and impacting logging performance/storage.

It is better to move the logging call after the loop and pass a merged input batch to log a single set of traces for the entire training step.

        generator_inputs = []
        for cur_generated_output_group in cur_generation_group_mini_batch:
            cur_staleness = self.global_step - cur_generated_output_group.global_step_when_scheduled
            stalenesses.append(cur_staleness)
            generator_outputs.append(cur_generated_output_group.generator_output)
            generator_inputs.append(cur_generated_output_group.generator_input)
            # NOTE(Charlie): for step-wise training each group can contain a variable number of entries
            # (n_samples_per_prompt * variable turns_per_trajectory), so the uid fanout is per-group.
            group_size = len(cur_generated_output_group.generator_output["response_ids"])
            uids.extend([cur_generated_output_group.uid] * group_size)

            # Check staleness violation.
            if cur_staleness > self.max_staleness_steps:
                # ... (rest of the loop logic) ...

        generator_output = concatenate_generator_outputs(
            generator_outputs, step_wise=self.cfg.generator.step_wise_trajectories
        )
        
        # Construct a merged input batch for logging traces
        merged_input = {
            "prompts": [g.generator_input["prompts"][0] for g in cur_generation_group_mini_batch],
            "trajectory_ids": [
                g.generator_input["trajectory_ids"][0]
                for g in cur_generation_group_mini_batch
                if g.generator_input.get("trajectory_ids")
            ],
            "batch_metadata": cur_generation_group_mini_batch[0].generator_input.get("batch_metadata"),
        }
        self._log_trackio_rollout_traces(merged_input, generator_output)

Author

Addressed in the latest push. Fully async now collects the generated groups first, concatenates the generator outputs, merges the corresponding generator inputs, and logs Trackio traces once for the training mini-batch so max_traces_per_step is enforced per step rather than per prompt group.
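One way the input-merging step described here might look is sketched below. The `merge_generator_inputs` helper is hypothetical and is not taken from the PR; it assumes each group's generator input is a dict with a `prompts` list and optional `trajectory_ids` and `batch_metadata` keys, as in the review snippet above. If any group lacks trajectory ids, the sketch drops them for the whole batch so that ids never fall out of alignment with prompts.

```python
from typing import Any, Dict, List, Optional


def merge_generator_inputs(group_inputs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: merge per-group generator inputs into one batch."""
    prompts: List[Any] = []
    trajectory_ids: Optional[List[Any]] = []
    batch_metadata = None
    for group_input in group_inputs:
        # Keep every prompt in the group, in order.
        prompts.extend(group_input["prompts"])
        ids = group_input.get("trajectory_ids")
        if not ids:
            # A group without ids makes the merged list unreliable.
            trajectory_ids = None
        elif trajectory_ids is not None:
            trajectory_ids.extend(ids)
        # Assumption: batch metadata is shared, so the first non-empty value is used.
        if batch_metadata is None:
            batch_metadata = group_input.get("batch_metadata")
    return {
        "prompts": prompts,
        "trajectory_ids": trajectory_ids,
        "batch_metadata": batch_metadata,
    }
```

The merged dict can then be passed, together with the concatenated generator output, to a single trace-logging call after the loop.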

Comment thread skyrl/train/trainer.py Outdated

    num_traces = min(max_traces, len(generator_output["response_ids"]))
    for sample_index in range(num_traces):
        input_idx = sample_index if sample_index < len(prompts) else len(prompts) - 1
Contributor

Priority: medium

The fallback logic for calculating input_idx is incorrect when n_samples_per_prompt > 1. If trajectory_ids are missing or fail to match, this logic will incorrectly map multiple responses to the same prompt or exceed the bounds of the prompts list.

Since the generator output layout is typically blocked (all samples for prompt 0, then all for prompt 1, etc.), the index should be calculated based on the number of samples per prompt.

Suggested change
- input_idx = sample_index if sample_index < len(prompts) else len(prompts) - 1
+ samples_per_prompt = len(generator_output["response_ids"]) // len(prompts)
+ input_idx = sample_index // samples_per_prompt

Author

Also addressed while updating the high-priority item. The fallback prompt mapping now computes a blocked sample-to-prompt index from the number of responses per prompt, with a bound check, and there is a focused test for the missing-trajectory-id case.
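A blocked sample-to-prompt mapping with a bound check, as described in this reply, might look like the sketch below. The function name and signature are hypothetical; the sketch assumes the blocked layout mentioned in the review (all samples for prompt 0, then all samples for prompt 1, and so on).

```python
def fallback_prompt_index(sample_index: int, num_responses: int, num_prompts: int) -> int:
    """Hypothetical helper: map a response index to its prompt index (blocked layout)."""
    if num_prompts <= 0:
        raise ValueError("at least one prompt is required")
    # Guard against zero when there are fewer responses than prompts.
    samples_per_prompt = max(1, num_responses // num_prompts)
    # Clamp so an uneven split can never index past the last prompt.
    return min(sample_index // samples_per_prompt, num_prompts - 1)


# 6 responses over 2 prompts: the first three map to prompt 0, the rest to prompt 1.
print([fallback_prompt_index(i, 6, 2) for i in range(6)])  # [0, 0, 0, 1, 1, 1]
```

The `max(1, ...)` guard avoids a division by zero when there are fewer responses than prompts, a case the plain integer division in the suggested change does not cover.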

@abidlabs abidlabs force-pushed the add-trackio-rollout-traces branch from 68c3dd0 to 83f8176 Compare May 22, 2026 23:49