diff --git a/docs/advance/rollout_skip.rst b/docs/advance/rollout_skip.rst index 6ed661d2bad..43d8fd8d943 100644 --- a/docs/advance/rollout_skip.rst +++ b/docs/advance/rollout_skip.rst @@ -43,12 +43,17 @@ Add these parameters to enable RolloutSkip: actor_rollout_ref.rollout.skip.dump_dir=/path/to/rollout_dump actor_rollout_ref.rollout.skip.max_dump_step=10 + # Optional: dump only on selected training steps (1-based), e.g. steps 1, 2, and 5: + actor_rollout_ref.rollout.skip.dump_steps=[1,2,5] + Configuration Parameters ------------------------ - **skip.enable**: Enable or disable RolloutSkip. - **skip.dump_dir**: Root directory to save cached rollout data. -- **skip.max_dump_step**: Maximum number of steps to cache. +- **skip.max_dump_step**: When ``dump_steps`` is unset or empty, dump/load while ``train_step <= max_dump_step``. +- **skip.dump_steps**: Optional explicit list of 1-based steps to dump/load. If non-empty, only those steps match; otherwise the ``max_dump_step`` window applies. Use null or ``[]`` for default behavior. +- **skip.action**: Applies on **non-dump** steps only. ``cache`` — always generate. ``repeat`` — reuse rollout files in round-robin over ``genstep_*`` dirs that were saved on earlier dump steps. ``repeat_last`` — reuse only the latest such dir. On dump steps, RolloutSkip always tries disk load first, then generate+dump if needed. Cached Directory Structure diff --git a/tests/utils/test_rollout_skip_on_cpu.py b/tests/utils/test_rollout_skip_on_cpu.py index 84338fee9ba..a49f77ce554 100644 --- a/tests/utils/test_rollout_skip_on_cpu.py +++ b/tests/utils/test_rollout_skip_on_cpu.py @@ -302,3 +302,26 @@ def test_rollout_with_CACHE_with_RESUME(self, mock_rollout_wg, step, capsys): # * Final skip.record(new_batch, step + resume_more_step + 1, None) # train_step start from 1 rollout_wg.generate_sequences(MagicMock()) + + +class TestDumpAtTrainSteps: + """rollout.skip.dump_steps overrides the max_dump_step window when non-empty.""" + + def test_custom_steps_only(self, mock_rollout_wg): + config, rollout_wg, _ = mock_rollout_wg + config.actor_rollout_ref.rollout.skip.dump_steps = [1, 2, 5] + config.actor_rollout_ref.rollout.skip.max_dump_step = 100 + skip = RolloutSkip(config, rollout_wg) + for t, expected in [(1, True), (2, True), (3, False), (5, True), (6, False)]: + skip.curr_train_step = t + assert skip.is_dump_step is expected, t + + def test_empty_list_uses_max_dump_step(self, mock_rollout_wg): + config, rollout_wg, _ = mock_rollout_wg + config.actor_rollout_ref.rollout.skip.dump_steps = [] + config.actor_rollout_ref.rollout.skip.max_dump_step = 3 + skip = RolloutSkip(config, rollout_wg) + skip.curr_train_step = 2 + assert skip.is_dump_step + skip.curr_train_step = 4 + assert not skip.is_dump_step diff --git a/verl/trainer/config/_generated_diffusion_trainer.yaml b/verl/trainer/config/_generated_diffusion_trainer.yaml index 88220e2c63c..b0c45db571a 100644 --- a/verl/trainer/config/_generated_diffusion_trainer.yaml +++ b/verl/trainer/config/_generated_diffusion_trainer.yaml @@ -316,6 +316,7 @@ actor_rollout_ref: enable: false dump_dir: ~/.verl/rollout_dump max_dump_step: 1 + dump_steps: null action: cache skip_tokenizer_init: true enable_rollout_routing_replay: false diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 2fd3ee130cc..64d64c1d7f6 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -332,6 +332,7 @@ actor_rollout_ref: enable: false dump_dir: ~/.verl/rollout_dump max_dump_step: 1 + dump_steps: null action: cache skip_tokenizer_init: true enable_rollout_routing_replay: false diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml index 002ce98fb60..db57b74fdc4 100644 --- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml @@ -303,6 +303,7 @@ actor_rollout_ref: enable: false dump_dir: ~/.verl/rollout_dump max_dump_step: 1 + dump_steps: null action: cache skip_tokenizer_init: true enable_rollout_routing_replay: false diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index fa1077eafa4..a38e28b5b3a 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -312,6 +312,7 @@ actor_rollout_ref: enable: false dump_dir: ~/.verl/rollout_dump max_dump_step: 1 + dump_steps: null action: cache skip_tokenizer_init: true enable_rollout_routing_replay: false diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index d8193518338..17f3a753b40 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -282,6 +282,7 @@ actor_rollout_ref: enable: false dump_dir: ~/.verl/rollout_dump max_dump_step: 1 + dump_steps: null action: cache skip_tokenizer_init: true enable_rollout_routing_replay: false diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index 01e5f2265b5..29860d0b11c 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -317,10 +317,14 @@ skip: # Number of training steps (from start) for which dump/load is active. Steps with train_step <= max_dump_step are "dump steps" (try load first, then generate and dump if missing). max_dump_step: 1 - # Action when beyond dump steps (gen_step > max_dump_step): - # - "cache": If dumped data exists for current step, use it; otherwise generate and dump. - # - "repeat": Reuse dumped data in a round-robin over the first max_dump_step steps. - # - "repeat_last": Reuse the last dumped step's data. + # Optional: only dump/load on these 1-based steps, e.g. [1, 2, 5]. If null or [], use max_dump_step above. + dump_steps: null + + # When the current training step is NOT a dump step (see dump_steps / max_dump_step and ``is_dump_step``): + # - "cache": Always run real generation (no reuse of rollout dumps on these steps). + # - "repeat": Round-robin load from genstep_* dirs that were written on earlier dump steps (``list_dumped_steps``). + # - "repeat_last": Load only from the most recently dumped genstep_*. + # On dump steps, behavior is always: try load existing dump for this gen_step, else generate and save. action: "cache" # Whether to skip tokenizer initialization for rollout engine diff --git a/verl/utils/rollout_skip.py b/verl/utils/rollout_skip.py index 04414265ea7..6768e99a583 100644 --- a/verl/utils/rollout_skip.py +++ b/verl/utils/rollout_skip.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import warnings from enum import Enum from pathlib import Path from typing import Any, Callable @@ -66,15 +67,20 @@ def _find_last_gen_step_for_train_step(step_file: Path, target_train_step: int) class SkipAction(Enum): - CACHE = "cache" # cache the sample. If dump_date is found, use it. If not found, dump it. - REPEAT = "repeat" # Repeat the sample when gen_step reach skip.max_dump_step - REPEAT_LAST = "repeat_last" # Repeat the last sample when gen_step reach skip.max_dump_step + # Used only when ``is_dump_step`` is False (see ``is_dump_step``). On dump steps, the wrapper always + # tries load-from-disk first, then generate+dump if missing. + CACHE = "cache" # Always run real generation; no reloading from rollout dump dirs. + REPEAT = ( + "repeat" # Round-robin reload from ``genstep_*`` dirs recorded in ``list_dumped_steps`` (from earlier dumps). + ) + REPEAT_LAST = "repeat_last" # Reload only from the last dumped ``genstep_*`` (``list_dumped_steps[-1]``). class RolloutSkip: """ - RolloutSkip skips sequence generation during rollout by attempting to load previously dumped data. - If no dumped data is found, it generates new sequences and saves them to disk. + RolloutSkip can reuse disk-cached rollout batches: on **dump steps** (see ``is_dump_step``), it tries to + load ``genstep_*`` data first and only generates+writes if missing. On **non-dump** steps, behavior is + controlled by ``skip.action`` (``cache`` / ``repeat`` / ``repeat_last``); see ``SkipAction``. Args: config: The configuration object containing rollout settings. @@ -118,7 +124,33 @@ def __init__(self, config, rollout_wg) -> None: self.action = _get_skip_attr(self.skip_config, "action", SkipAction.REPEAT) self.action = SkipAction(self.action) - if self.max_dump_step <= 0: + raw_steps = _get_skip_attr(self.skip_config, "dump_steps", None) + self._dump_step_set: frozenset[int] | None = None + if raw_steps is not None: + if isinstance(raw_steps, str): + try: + raw_steps = json.loads(raw_steps) + except json.JSONDecodeError: + warnings.warn( + f"{self.print_mark}Could not parse 'dump_steps' string: {raw_steps!r}. " + "Falling back to max_dump_step.", + stacklevel=2, + ) + raw_steps = None + if raw_steps: + try: + step_list = [int(x) for x in raw_steps] + if step_list: + self._dump_step_set = frozenset(step_list) + except (ValueError, TypeError): + warnings.warn( + f"{self.print_mark}'dump_steps' must be a list of integers, got: {raw_steps!r}. " + "Falling back to max_dump_step.", + stacklevel=2, + ) + + _use_max_dump_window = self._dump_step_set is None + if _use_max_dump_window and self.max_dump_step <= 0: assert self.action in [SkipAction.CACHE] self._create_dump_path() @@ -133,10 +165,23 @@ def is_active(self) -> bool: @property def is_dump_step(self) -> bool: """ - Determine if the current step is a dump step based on the configured dump interval. - If train_step is given, it follows the train_step, otherwise it follows the gen_step. + Whether the current training step should run the dump/load path (try disk cache before generating). + + **Explicit list** — If ``dump_steps`` is non-empty, only steps whose index is in that list + are dump steps. Step indices are 1-based (first update is 1), same counter as used for + ``max_dump_step`` below—not a different numbering scheme. + + **Window** — If ``dump_steps`` is null or empty, dump/load while ``curr_train_step <= max_dump_step``. + + ``curr_train_step`` follows ``record(..., global_steps=...)`` when the trainer passes it; + otherwise it is advanced by ``step()`` each rollout. """ - return self.is_active and self.curr_train_step <= self.max_dump_step + if not self.is_active: + return False + t = self.curr_train_step + if self._dump_step_set is not None: + return t in self._dump_step_set + return t <= self.max_dump_step @property def num_dumped_step(self) -> int: @@ -238,7 +283,7 @@ def record( if found is not None: last_train_step, last_gen_step = found if last_train_step + 1 != global_steps: - print(f"{self.print_mark}\033[31mWarning: Train step not continues.\033[0m") + print(f"{self.print_mark}\033[31mWarning: Train step not continuous.\033[0m") self.__gen_offset_step = last_gen_step except Exception as e: print( diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 59427ac10cd..b74b8073662 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -47,6 +47,10 @@ class SkipConfig(BaseConfig): enable: bool = False dump_dir: str = "~/.verl/rollout_dump" max_dump_step: int = 1 + # If set (non-empty), dump/load only on these 1-based steps. + # If null or empty, fall back to the max_dump_step window. + dump_steps: Optional[list[int]] = None + # When not a dump step: cache=always generate; repeat=round-robin from dumped gen_steps; repeat_last=last dump only. action: str = "cache" # cache | repeat | repeat_last def get(self, key: str, default=None):