Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs/advance/rollout_skip.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,17 @@ Add these parameters to enable RolloutSkip:
actor_rollout_ref.rollout.skip.dump_dir=/path/to/rollout_dump
actor_rollout_ref.rollout.skip.max_dump_step=10

# Optional: dump only on selected training steps (1-based), e.g. steps 1, 2, and 5:
actor_rollout_ref.rollout.skip.dump_steps=[1,2,5]


Configuration Parameters
------------------------
- **skip.enable**: Enable or disable RolloutSkip.
- **skip.dump_dir**: Root directory to save cached rollout data.
- **skip.max_dump_step**: Maximum number of steps to cache.
- **skip.max_dump_step**: When ``dump_steps`` is unset or empty, dump/load while ``train_step <= max_dump_step``.
- **skip.dump_steps**: Optional explicit list of 1-based steps to dump/load. If non-empty, only those steps match; otherwise the ``max_dump_step`` window applies. Set to null or ``[]`` to keep the default ``max_dump_step`` behavior.
- **skip.action**: Applies on **non-dump** steps only. ``cache`` — always generate. ``repeat`` — reuse rollout files in round-robin over ``genstep_*`` dirs that were saved on earlier dump steps. ``repeat_last`` — reuse only the latest such dir. On dump steps, RolloutSkip always tries disk load first, then generate+dump if needed.


Cached Directory Structure
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_rollout_skip_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,26 @@ def test_rollout_with_CACHE_with_RESUME(self, mock_rollout_wg, step, capsys):
# * Final
skip.record(new_batch, step + resume_more_step + 1, None) # train_step start from 1
rollout_wg.generate_sequences(MagicMock())


class TestDumpAtTrainSteps:
    """rollout.skip.dump_steps overrides the max_dump_step window when non-empty."""

    def test_custom_steps_only(self, mock_rollout_wg):
        # A non-empty dump_steps list takes precedence over the max_dump_step window:
        # only the listed (1-based) steps are dump steps, even though max_dump_step=100.
        cfg, worker_group, _ = mock_rollout_wg
        cfg.actor_rollout_ref.rollout.skip.dump_steps = [1, 2, 5]
        cfg.actor_rollout_ref.rollout.skip.max_dump_step = 100
        rollout_skip = RolloutSkip(cfg, worker_group)
        expectations = {1: True, 2: True, 3: False, 5: True, 6: False}
        for train_step, should_dump in expectations.items():
            rollout_skip.curr_train_step = train_step
            assert rollout_skip.is_dump_step is should_dump, train_step

    def test_empty_list_uses_max_dump_step(self, mock_rollout_wg):
        # An empty dump_steps list means "unset": fall back to train_step <= max_dump_step.
        cfg, worker_group, _ = mock_rollout_wg
        cfg.actor_rollout_ref.rollout.skip.dump_steps = []
        cfg.actor_rollout_ref.rollout.skip.max_dump_step = 3
        rollout_skip = RolloutSkip(cfg, worker_group)
        rollout_skip.curr_train_step = 2
        assert rollout_skip.is_dump_step
        rollout_skip.curr_train_step = 4
        assert not rollout_skip.is_dump_step
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_diffusion_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ actor_rollout_ref:
enable: false
dump_dir: ~/.verl/rollout_dump
max_dump_step: 1
dump_steps: null
action: cache
skip_tokenizer_init: true
enable_rollout_routing_replay: false
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ actor_rollout_ref:
enable: false
dump_dir: ~/.verl/rollout_dump
max_dump_step: 1
dump_steps: null
action: cache
skip_tokenizer_init: true
enable_rollout_routing_replay: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ actor_rollout_ref:
enable: false
dump_dir: ~/.verl/rollout_dump
max_dump_step: 1
dump_steps: null
action: cache
skip_tokenizer_init: true
enable_rollout_routing_replay: false
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ actor_rollout_ref:
enable: false
dump_dir: ~/.verl/rollout_dump
max_dump_step: 1
dump_steps: null
action: cache
skip_tokenizer_init: true
enable_rollout_routing_replay: false
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ actor_rollout_ref:
enable: false
dump_dir: ~/.verl/rollout_dump
max_dump_step: 1
dump_steps: null
action: cache
skip_tokenizer_init: true
enable_rollout_routing_replay: false
Expand Down
12 changes: 8 additions & 4 deletions verl/trainer/config/rollout/rollout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,14 @@ skip:
# Number of training steps (from start) for which dump/load is active. Steps with train_step <= max_dump_step are "dump steps" (try load first, then generate and dump if missing). Ignored when dump_steps below is non-empty.
max_dump_step: 1

# Action when beyond dump steps (gen_step > max_dump_step):
# - "cache": If dumped data exists for current step, use it; otherwise generate and dump.
# - "repeat": Reuse dumped data in a round-robin over the first max_dump_step steps.
# - "repeat_last": Reuse the last dumped step's data.
# Optional: only dump/load on these 1-based steps, e.g. [1, 2, 5]. If null or [], use max_dump_step above.
dump_steps: null

# When the current training step is NOT a dump step (see dump_steps / max_dump_step and ``is_dump_step``):
# - "cache": Always run real generation (no reuse of rollout dumps on these steps).
# - "repeat": Round-robin load from genstep_* dirs that were written on earlier dump steps (``list_dumped_steps``).
# - "repeat_last": Load only from the most recently dumped genstep_*.
# On dump steps, behavior is always: try load existing dump for this gen_step, else generate and save.
action: "cache"

# Whether to skip tokenizer initialization for rollout engine
Expand Down
65 changes: 55 additions & 10 deletions verl/utils/rollout_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import warnings
from enum import Enum
from pathlib import Path
from typing import Any, Callable
Expand Down Expand Up @@ -66,15 +67,20 @@ def _find_last_gen_step_for_train_step(step_file: Path, target_train_step: int)


class SkipAction(Enum):
CACHE = "cache" # cache the sample. If dump_date is found, use it. If not found, dump it.
REPEAT = "repeat" # Repeat the sample when gen_step reach skip.max_dump_step
REPEAT_LAST = "repeat_last" # Repeat the last sample when gen_step reach skip.max_dump_step
# Used only when ``is_dump_step`` is False (see ``is_dump_step``). On dump steps, the wrapper always
# tries load-from-disk first, then generate+dump if missing.
CACHE = "cache" # Always run real generation; no reloading from rollout dump dirs.
REPEAT = (
"repeat" # Round-robin reload from ``genstep_*`` dirs recorded in ``list_dumped_steps`` (from earlier dumps).
)
REPEAT_LAST = "repeat_last" # Reload only from the last dumped ``genstep_*`` (``list_dumped_steps[-1]``).


class RolloutSkip:
"""
RolloutSkip skips sequence generation during rollout by attempting to load previously dumped data.
If no dumped data is found, it generates new sequences and saves them to disk.
RolloutSkip can reuse disk-cached rollout batches: on **dump steps** (see ``is_dump_step``), it tries to
load ``genstep_*`` data first and only generates+writes if missing. On **non-dump** steps, behavior is
controlled by ``skip.action`` (``cache`` / ``repeat`` / ``repeat_last``); see ``SkipAction``.

Args:
config: The configuration object containing rollout settings.
Expand Down Expand Up @@ -118,7 +124,33 @@ def __init__(self, config, rollout_wg) -> None:
self.action = _get_skip_attr(self.skip_config, "action", SkipAction.REPEAT)
self.action = SkipAction(self.action)

if self.max_dump_step <= 0:
raw_steps = _get_skip_attr(self.skip_config, "dump_steps", None)
self._dump_step_set: frozenset[int] | None = None
if raw_steps is not None:
if isinstance(raw_steps, str):
try:
raw_steps = json.loads(raw_steps)
except json.JSONDecodeError:
warnings.warn(
f"{self.print_mark}Could not parse 'dump_steps' string: {raw_steps!r}. "
"Falling back to max_dump_step.",
stacklevel=2,
)
raw_steps = None
if raw_steps:
try:
step_list = [int(x) for x in raw_steps]
if step_list:
self._dump_step_set = frozenset(step_list)
except (ValueError, TypeError):
warnings.warn(
f"{self.print_mark}'dump_steps' must be a list of integers, got: {raw_steps!r}. "
"Falling back to max_dump_step.",
stacklevel=2,
)

_use_max_dump_window = self._dump_step_set is None
if _use_max_dump_window and self.max_dump_step <= 0:
assert self.action in [SkipAction.CACHE]

self._create_dump_path()
Expand All @@ -133,10 +165,23 @@ def is_active(self) -> bool:
@property
def is_dump_step(self) -> bool:
"""
Determine if the current step is a dump step based on the configured dump interval.
If train_step is given, it follows the train_step, otherwise it follows the gen_step.
Whether the current training step should run the dump/load path (try disk cache before generating).

**Explicit list** — If ``dump_steps`` is non-empty, only steps whose index is in that list
are dump steps. Step indices are 1-based (first update is 1), same counter as used for
``max_dump_step`` below—not a different numbering scheme.

**Window** — If ``dump_steps`` is null or empty, dump/load while ``curr_train_step <= max_dump_step``.

``curr_train_step`` follows ``record(..., global_steps=...)`` when the trainer passes it;
otherwise it is advanced by ``step()`` each rollout.
"""
return self.is_active and self.curr_train_step <= self.max_dump_step
if not self.is_active:
return False
t = self.curr_train_step
if self._dump_step_set is not None:
return t in self._dump_step_set
return t <= self.max_dump_step

@property
def num_dumped_step(self) -> int:
Expand Down Expand Up @@ -238,7 +283,7 @@ def record(
if found is not None:
last_train_step, last_gen_step = found
if last_train_step + 1 != global_steps:
print(f"{self.print_mark}\033[31mWarning: Train step not continues.\033[0m")
print(f"{self.print_mark}\033[31mWarning: Train step not continuous.\033[0m")
self.__gen_offset_step = last_gen_step
except Exception as e:
print(
Expand Down
4 changes: 4 additions & 0 deletions verl/workers/config/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SkipConfig(BaseConfig):
enable: bool = False
dump_dir: str = "~/.verl/rollout_dump"
max_dump_step: int = 1
# If set (non-empty), dump/load only on these 1-based steps.
# If null or empty, fall back to the max_dump_step window.
dump_steps: Optional[list[int]] = None
# When not a dump step: cache=always generate; repeat=round-robin from dumped gen_steps; repeat_last=last dump only.
action: str = "cache" # cache | repeat | repeat_last

def get(self, key: str, default=None):
Expand Down