Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions .claude/docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,27 @@ Don't:
# Example (full epoch over the 5000 rows):
# bash examples/train/sft/run_sft_megatron_apigen_mt.sh num_epochs=1 num_steps=null

```

## Error messages

The same holds true for error messages:

Do:

```python
if self._callback_handler.callbacks:
raise NotImplementedError(
"Callbacks are not yet supported by `FullyAsyncRayPPOTrainer`. "
)
```

Don't:

```python
if self._callback_handler.callbacks:
raise NotImplementedError(
"Callbacks are not yet supported by `FullyAsyncRayPPOTrainer`. "
"Track in a follow-up; the sync RayPPOTrainer and SFTTrainer do support them."
)
```
49 changes: 49 additions & 0 deletions examples/train/callbacks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Training callbacks

Demonstrates the `TrainingCallback` API by adding a `PerplexityLogger`
callback to the SFT trainer. The same pattern works for the RL trainer
(`RayPPOTrainer` accepts a `callbacks=` constructor arg).

## Files

- `perplexity_logger.py` — example callback that logs `train/perplexity` on
every step, piggy-backing on the trainer's own wandb step.
- `main_sft_with_callbacks.py` — custom entrypoint that constructs
`SFTTrainer(..., callbacks=[PerplexityLogger()])`.
- `run_sft_with_callbacks.sh` — launcher; mirrors `examples/train/sft/run_sft_fsdp.sh`
but runs through the custom entrypoint.

## Run

```bash
bash examples/train/callbacks/run_sft_with_callbacks.sh
```

## Writing your own callback

Subclass `TrainingCallback` and override the events you care about. Every
event receives the same three arguments: `(trainer, callback_input, control)`.

```python
from skyrl.train.utils.callbacks import TrainingCallback

class LogGradNorm(TrainingCallback):
def on_step_end(self, trainer, callback_input, control):
gn = (callback_input.metrics or {}).get("grad_norm")
if gn is None:
return
trainer.tracker.log(
{"diag/grad_norm": gn},
step=callback_input.global_step,
commit=False,
)
```

`callback_input` carries the loop counters plus the per-event payload that
applies (`batch` on step events, `metrics` on step/eval end, `logs` on
`on_log`, `ckpt_path` on `on_save`). Anything else — `tokenizer`, `dispatch`,
`tracker`, `cfg` — is reached through `trainer.*`.

Set `control.should_save` / `should_evaluate` to request a checkpoint save
or an eval pass at the end of the current step; the trainer honors and
resets those flags.
40 changes: 40 additions & 0 deletions examples/train/callbacks/main_sft_with_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""SFT entrypoint demonstrating the PerplexityLogger callback.

Usage:
bash examples/train/callbacks/run_sft_with_callbacks.sh
"""

import sys

import ray

from examples.train.callbacks.perplexity_logger import PerplexityLogger
from skyrl.train.config import SkyRLTrainConfig
from skyrl.train.config.sft_config import (
SFTConfig,
build_skyrl_config_for_sft,
validate_sft_cfg,
)
from skyrl.train.sft_trainer import SFTTrainer
from skyrl.train.utils.utils import initialize_ray


@ray.remote(num_cpus=1)
def sft_entrypoint(cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig):
callbacks = [PerplexityLogger()]
trainer = SFTTrainer(cfg, skyrl_cfg=skyrl_cfg, callbacks=callbacks)
trainer.setup()
trainer.train()
trainer.shutdown()


def main():
cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
validate_sft_cfg(cfg)
skyrl_cfg = build_skyrl_config_for_sft(cfg)
initialize_ray(skyrl_cfg)
ray.get(sft_entrypoint.remote(cfg, skyrl_cfg))


if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions examples/train/callbacks/perplexity_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""PerplexityLogger callback — logs train perplexity alongside the trainer's
own metrics by writing into the same wandb step via ``trainer.tracker``.
"""

import math

from skyrl.train.utils.callbacks import (
CallbackInput,
TrainingCallback,
TrainingControl,
)


class PerplexityLogger(TrainingCallback):
"""Log ``train/perplexity = exp(loss)`` on every step.

Uses ``commit=False`` so the perplexity value is bundled into the same
wandb step the trainer commits on its own. The ``min(loss, 20)`` cap
keeps ``exp`` from overflowing on the first few unstable steps.
"""

def on_step_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
metrics = callback_input.metrics or {}
loss = metrics.get("loss")
if loss is None or math.isnan(loss):
return
trainer.tracker.log(
{"train/perplexity": math.exp(min(loss, 20))},
step=callback_input.global_step,
commit=False,
)
39 changes: 39 additions & 0 deletions examples/train/callbacks/run_sft_with_callbacks.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
set -x

# SFT training with an EarlyStopping callback on eval_loss.
#
# Usage:
# bash examples/train/callbacks/run_sft_with_callbacks.sh [extra overrides...]

uv run --isolated --extra fsdp \
-m examples.train.callbacks.main_sft_with_callbacks \
strategy=fsdp \
model.path=Qwen/Qwen2.5-0.5B-Instruct \
dataset_name=yahma/alpaca-cleaned \
dataset_split="train[:200]" \
eval_dataset_name=yahma/alpaca-cleaned \
eval_dataset_split="train[200:240]" \
messages_key=messages \
max_length=512 \
num_steps=40 \
eval_interval=5 \
eval_before_train=true \
batch_size=4 \
micro_train_batch_size_per_gpu=2 \
use_sample_packing=true \
seed=42 \
optimizer_config.lr=1e-6 \
optimizer_config.weight_decay=1e-2 \
optimizer_config.max_grad_norm=1.0 \
optimizer_config.num_warmup_steps=0 \
optimizer_config.scheduler=constant_with_warmup \
placement.num_nodes=1 \
placement.num_gpus_per_node=1 \
fsdp_config.cpu_offload=false \
fsdp_config.reshard_after_forward=true \
logger=console \
project_name=skyrl_sft_callbacks \
run_name=skyrl_sft_callbacks_run \
ckpt_path="" \
"$@"
4 changes: 2 additions & 2 deletions skyrl/train/config/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
logger: str = "console" # "console" or "wandb"
project_name: str = "skyrl_sft"
run_name: str = "skyrl_sft_run"
ckpt_path: str = "" # empty string = no checkpointing
ckpt_interval: int = 0
ckpt_path: str = ""
ckpt_interval: int = 0 # <= 0 -> no checkpointing
max_ckpts_to_keep: int = -1
"""-1 to keep all checkpoints, N to keep only the last N."""
resume_from: str = "" # "" = no resume, "latest" = latest checkpoint, or path to global_step_N dir
Expand Down
18 changes: 15 additions & 3 deletions skyrl/train/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ def __init__(self, *args, **kwargs):
# Initialize base trainer
super().__init__(*args, **kwargs)

# Callbacks aren't wired into FullyAsyncRayPPOTrainer.train() yet — fail
# fast rather than silently dropping events
if self._callback_handler.callbacks:
raise NotImplementedError(
"Callbacks are not yet supported by FullyAsyncRayPPOTrainer. "
"Track in a follow-up; the sync RayPPOTrainer and SFTTrainer do support them."
)

# Some async-specific validations
assert (
self.cfg.trainer.train_batch_size == self.cfg.trainer.policy_mini_batch_size
Expand All @@ -312,6 +320,9 @@ def __init__(self, *args, **kwargs):
max_staleness_steps=self.max_staleness_steps,
)

def add_callback(self, callback):
raise NotImplementedError("Callbacks are not yet supported by FullyAsyncRayPPOTrainer. ")

def _build_train_dataloader_and_compute_training_steps(self):
"""
Overrides to build dataloader for fully async training. See `_AsyncDataloader` for more details.
Expand Down Expand Up @@ -666,26 +677,27 @@ def convert_generation_group_mini_batch_to_training_input(

return self.convert_to_training_input(generator_output, uids)

def save_checkpoints(self):
def save_checkpoints(self) -> str:
"""
Extend base checkpointing by recording consumed UIDs for fully-async training.

Otherwise, when resuming, there is no way to know which data has been trained on.
Returns the checkpoint folder path (forwarded from the base implementation).
"""
consumed_uids_list = (
self.async_train_dataloader.get_consumed_uids_list()
) # read first to prevent race condition
# The base method will save the model, dataloader path, trainer_state, and latest_ckpt_global_step.txt.
super().save_checkpoints()
global_step_folder = super().save_checkpoints()
# In addition, we need to save the consumed UIDs -- the data that we have already trained on.
global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
fully_async_state_path = os.path.join(global_step_folder, "fully_async_state.pt")
fully_async_state = {
"consumed_uids": consumed_uids_list,
}
with io.open_file(fully_async_state_path, "wb") as f:
torch.save(fully_async_state, f)
logger.info(f"Saved fully-async state to {fully_async_state_path}")
return global_step_folder

def load_checkpoints(self) -> Tuple[int, str, Set[str]]:
"""
Expand Down
Loading
Loading