Skip to content

[feat] Support training callbacks in SkyRL#1696

Draft
SumanthRH wants to merge 11 commits into
mainfrom
sft-callback
Draft

[feat] Support training callbacks in SkyRL#1696
SumanthRH wants to merge 11 commits into
mainfrom
sft-callback

Conversation

@SumanthRH
Copy link
Copy Markdown
Member

@SumanthRH SumanthRH commented May 22, 2026

What does this PR do?

Adds training callbacks in SkyRL for the SFT and RL training.

Training callbacks are a standard pattern to handle different training events as well as inject custom training behaviour (ex: early stopping). Historically, SkyRL has not prioritized such customization behaviours and users would instead subclass the trainer for this.

Design

The design for training callbacks is heavily inspired by HuggingFace's Trainer.

Core Types

  1. A CallbackInput dataclass
@dataclass
class CallbackInput:
    """Training state passed to every callback event.

    Read-only from the callback's perspective. The trainer refreshes
   this before each event dispatch.

    Always-populated fields are the loop counters at the top. The remaining
    fields are populated only when relevant to the firing event (see table
    below) and are `None` otherwise — callbacks should null-check the fields
    they care about.
    """
    # Always populated
    global_step: int
    epoch: int
    total_steps: int
    steps_per_epoch: int

    # Step events (`on_step_start`, `on_step_end`)
    batch: Optional["TrainingInputBatch"] = None

    # Step end + eval end
    metrics: Optional[Dict[str, Any]] = None

    # `on_log` only — the dict the trainer is about to commit to the tracker.
    # Callbacks may mutate it in place to add fields.
    logs: Optional[Dict[str, Any]] = None

    # `on_save` only
    ckpt_path: Optional[str] = None
  1. A TrainingControl dataclass with mutable flags that callbacks can set to influence the trainer.
@dataclass
class TrainingControl:
    """Mutable flags callbacks can set to influence the trainer.

    The trainer checks these flags after each callback dispatch:
      - should_save:           force a checkpoint at end of current step
      - should_evaluate:       run eval at end of current step
      - should_training_stop:  exit the training loop after current step
    Flags are reset by the trainer after they have been honored.
    """
    should_save: bool = False
    should_evaluate: bool = False
    should_training_stop: bool = False
  1. A TrainingCallback class, allowing users to define callbacks on training events
class TrainingCallback:
    """Base class. Override the events you care about; others are no-ops.

    Every event receives the same three arguments:
      - trainer:  the SFTTrainer or RayPPOTrainer instance. The stable surface
                  is `trainer.cfg`, `trainer.tracker`, `trainer.tokenizer`,
                  `trainer.dispatch`, `trainer.global_step`.
      - callback_input: a `CallbackInput` snapshot. Which fields are populated
                  depends on the event (see table below).
      - control:  a mutable `TrainingControl` the callback can set.
    """

    def on_train_start(self, trainer, callback_input, control): ...
    def on_train_end(self, trainer, callback_input, control): ...

    def on_epoch_start(self, trainer, callback_input, control): ...
    def on_epoch_end(self, trainer, callback_input, control): ...

    def on_step_start(self, trainer, callback_input, control): ...
    def on_step_end(self, trainer, callback_input, control): ...

    def on_eval_start(self, trainer, callback_input, control): ...
    def on_eval_end(self, trainer, callback_input, control): ...

    def on_save(self, trainer, callback_input, control): ...
    def on_log(self, trainer, callback_input, control): ...

Fields populated per event

Event callback_input.batch callback_input.metrics callback_input.logs callback_input.ckpt_path
on_train_start
on_train_end
on_epoch_start
on_epoch_end
on_step_start
on_step_end
on_eval_start
on_eval_end
on_save
on_log

Callback Handler

This PR also defines a CallbackHandler for handling a list of training callbacks. this is the primary interface for trainers to interact with callbacks

class CallbackHandler(TrainingCallback):
    def __init__(self, callbacks): self.callbacks = list(callbacks)

    def _dispatch(self, name, trainer, callback_input, control):
        for cb in self.callbacks:
            getattr(cb, name)(trainer, callback_input, control)

    def on_step_end(self, trainer, callback_input, control):
        self._dispatch("on_step_end", trainer, callback_input, control)
    # ... etc

Registration

class SFTTrainer:
    def __init__(self, cfg, skyrl_cfg=None, callbacks=None):
        self._callback_handler = CallbackHandler(callbacks)
        ...

    def add_callback(self, callback: TrainingCallback):
        self._callback_handler.add(callback)

Constructor is the default path for adding callbacks, and add_callback is a convenient helper for adding callbacks post-init.

Example Usage

An example is provided in examples/train/callbacks , adding the common early stopping callback:

class EarlyStopping(TrainingCallback):
    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.stale = 0

    def on_eval_end(self, trainer, callback_input, control):
        eval_loss = (callback_input.metrics or {}).get("eval_loss")
        if eval_loss is None: return
        if eval_loss + self.min_delta < self.best:
            self.best = eval_loss
            self.stale = 0
        else:
            self.stale += 1
            if self.stale >= self.patience:
                control.should_training_stop = True

Test Plan

  • An E2E test script tests/backends/skyrl_train/gpu/gpu_ci/test_sft_callbacks.py which starts a dummy SFT run and sanity checks that the relevant callbacks have been triggered
  • Unit tests: Most of the non-trivial code in this PR is in the trainer with control logic for triggering callbacks, which is why I've chosen to only add a E2E test. Most individual components added in the PR are base classes and trivial dataclasses.

Limitations

We do not add support for callbacks to the fully async trainer in this PR, and this will be done as a follow-up.

TODO:

  • Add an E2E test script for RL trainer with callbacks
  • Cleanup test

SumanthRH and others added 11 commits May 19, 2026 23:00
Introduces an HF/Lightning-style callback API so users can hook training
events without subclassing the trainer. Surfaces ten events (train, epoch,
step, eval × start/end, plus on_save, on_log), passes a unified
CallbackInput with always-populated counters + per-event optional fields
(batch / metrics / logs / ckpt_path), and exposes a mutable TrainingControl
for early-stop / force-save / force-eval requests.

Wired into both SFTTrainer and RayPPOTrainer. Callbacks register via the
callbacks= constructor arg or add_callback() post-construction.
fully_async_trainer.save_checkpoints() forwards the base path so on_save
receives a non-None ckpt_path. The baseline eval_before_train pass fires
on_eval_start / on_eval_end / on_log on both trainers, and RL aligns
post-loop global_step before firing on_save / on_train_end.

Ships an example at examples/train/callbacks/ with an EarlyStopping
callback, a custom Ray entrypoint, and a launcher script.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Signed-off-by: SumanthRH <[email protected]>
Signed-off-by: SumanthRH <[email protected]>
Signed-off-by: SumanthRH <[email protected]>
x
Signed-off-by: SumanthRH <[email protected]>
x
Signed-off-by: SumanthRH <[email protected]>
x
Signed-off-by: SumanthRH <[email protected]>
Signed-off-by: SumanthRH <[email protected]>
Mirrors the structure of test_rl_callbacks.py: mocks dispatch + tokenizer,
skips trainer.setup(), and exercises only the orchestration in
SFTTrainer.train(). Same event-sequence + payload assertions as before,
but no real FSDP / model load -> runs without GPUs.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Signed-off-by: SumanthRH <[email protected]>
Signed-off-by: SumanthRH <[email protected]>

# Conflicts:
#	skyrl/train/sft_trainer.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant