diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx index 6b4a5d3731..83022572bd 100644 --- a/docs/content/docs/configuration/config.mdx +++ b/docs/content/docs/configuration/config.mdx @@ -672,6 +672,8 @@ generator: zero_reward_on_non_stop: false + skip_failed_rollouts: false + apply_overlong_filtering: false ``` @@ -747,5 +749,6 @@ For more details on how different placement options work, please refer to the [p ### Misc Configuration - `generator.zero_reward_on_non_stop`: Whether to set the reward to 0 if the `stop_reason` is not `stop`. Cases where this is useful: Often, we have format rewards for the LLM to follow, but in cases where the LLM didn't finish the response, we typically don't want to reward it. This is a general setting for all environments. +- `generator.skip_failed_rollouts`: Whether to skip individual failed non-batched rollouts by replacing each failed row with a zero-reward, loss-masked placeholder whose `stop_reason` is `rollout_error`. This catches normal rollout exceptions only; cancellations and interrupts still stop the training step. - `generator.apply_overlong_filtering`: Whether to apply DAPO Overlong Filtering to the loss masks. For each trajectory that exceeds the max length (i.e., truncated and does not end with an EOS token), this masks out every token in the loss mask. - `generator.step_wise_trajectories`: Whether to return outputs in a step-wise fashion. If `true`, then the generator will return multi-turn generations with the (prompt, response) pair of each turn being a separate trajectory. Advantages are computed based on the last step of each trajectory and propagated to the previous steps. diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index f28d997ba4..6ad987b324 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -544,6 +544,8 @@ class GeneratorConfig(BaseConfig): eval_n_samples_per_prompt: int = 1 zero_reward_on_non_stop: bool = False """Set reward to 0 when ``stop_reason`` is not ``"stop"`` (i.e., generation was truncated or aborted).""" + skip_failed_rollouts: bool = False + """Replace failed non-batched rollouts with zero-reward, loss-masked placeholders.""" apply_overlong_filtering: bool = False """Apply DAPO Overlong Filtering: mask out all tokens in the loss mask for trajectories that exceed max length (truncated, no EOS token).""" diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index b10695c5b3..d5f1820379 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -380,6 +380,11 @@ generator: # TODO (erictang000): Show clear ablations for benefits of this on GSM8K or SQL. zero_reward_on_non_stop: false + # Whether to skip individual failed rollouts by substituting zero-reward, + # loss-masked placeholder rows with stop_reason="rollout_error". + # This is only supported for non-batched generation. + skip_failed_rollouts: false + # Whether to apply DAPO Overlong Filtering to the loss masks. # For each trajectory that exceeds the max length (i.e., truncated and does not end with an # EOS token), this masks out every token in the loss mask. @@ -395,4 +400,4 @@ generator: environment: env_class: "gsm8k" # NOTE: environment specific defaults for environment.skyrl_gym are set at the following path: - # skyrl_gym: config/skyrl_gym_config/default.yaml \ No newline at end of file + # skyrl_gym: config/skyrl_gym_config/default.yaml diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 68b635e8c9..b53f505323 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -9,7 +9,7 @@ import copy from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union from uuid import uuid4 import torch @@ -32,6 +32,7 @@ TrajectoryID, ) from skyrl.train.generators.utils import ( + ROLLOUT_ERROR_STOP_REASON, apply_overlong_filtering, get_custom_chat_template, get_generation_prompt_ids, @@ -169,7 +170,8 @@ def __init__( self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None if self.skyrl_gym_cfg.max_env_workers > 0: self.env_executor = ThreadPoolExecutor( - max_workers=self.skyrl_gym_cfg.max_env_workers, thread_name_prefix="skyrl-gym-env-" + max_workers=self.skyrl_gym_cfg.max_env_workers, + thread_name_prefix="skyrl-gym-env-", ) else: self.env_executor = None @@ -205,6 +207,8 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig): raise ValueError( "`chat_template_kwargs` is not compatible with `batched=True` since the chat templating is handled by the inference engine" ) + if generator_cfg.skip_failed_rollouts and generator_cfg.batched: + raise ValueError("`skip_failed_rollouts=True` is only supported with `batched=False`.") if self.generator_cfg.step_wise_trajectories: if self.batched: @@ -228,6 +232,80 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs): else: return func(*args, **kwargs) + async def _close_env_after_exception(self, env, context: str): + try: + await self._run_in_executor_if_available(env.close) + except Exception as close_exc: + logger.opt(exception=close_exc).warning( + "Failed to close SkyRL-Gym environment after {} failure: {}", + context, + close_exc, + ) + + def _uses_rollout_logprobs(self, sampling_params: Optional[dict]) -> bool: + if sampling_params is not None: + return sampling_params.get("logprobs", None) is not None + return self.generator_cfg.sampling_params.logprobs is not None + + def _placeholder_token_id(self) -> int: + for attr_name in ("eos_token_id", "pad_token_id"): + token_id = getattr(self.tokenizer, attr_name, None) + if token_id is not None: + return token_id + return 0 + + def _failed_rollout_placeholder(self, include_logprobs: bool) -> Union[TrajectoryOutput, StepWiseOutput]: + token_id = self._placeholder_token_id() + reward: Union[float, List[float]] = ( + 0.0 if self.custom_chat_template and not self.generator_cfg.step_wise_trajectories else [0.0] + ) + output = TrajectoryOutput( + response_ids=[token_id], + reward=reward, + stop_reason=ROLLOUT_ERROR_STOP_REASON, + loss_mask=[0], + prompt_ids=[token_id], + rollout_logprobs=[0.0] if include_logprobs else None, + env_metrics={}, + rollout_expert_indices=None, + ) + if self.generator_cfg.step_wise_trajectories: + return StepWiseOutput(step_outputs=[output]) + return output + + async def _safe_rollout( + self, + idx: int, + env_class: str, + trajectory_id: Optional[TrajectoryID], + rollout: Awaitable[Union[TrajectoryOutput, StepWiseOutput]], + include_logprobs: bool, + ) -> Union[TrajectoryOutput, StepWiseOutput]: + try: + return await rollout + except asyncio.CancelledError: + raise + except Exception as exc: + trajectory = trajectory_id.to_string() if trajectory_id is not None else None + logger.opt(exception=exc).warning( + "SkyRLGym rollout {} failed for env_class={} trajectory_id={} with {}: {}; " + "substituting zero-reward placeholder with stop_reason={}", + idx, + env_class, + trajectory, + type(exc).__name__, + exc, + ROLLOUT_ERROR_STOP_REASON, + ) + return self._failed_rollout_placeholder(include_logprobs=include_logprobs) + + def _normalize_optional_tensor_features(self, values: List[Optional[torch.Tensor]]) -> Optional[List[torch.Tensor]]: + ref = next((value for value in values if value is not None), None) + if ref is None: + return None + placeholder = torch.empty(0, *ref.shape[1:], dtype=ref.dtype, device=ref.device) + return [value if value is not None else placeholder for value in values] + # ------------------------------------------------------------------ # Subclass hooks. Default implementations are no-ops so generic envs # see the upstream behavior; subclasses (e.g. RLMGymGenerator) override. @@ -313,18 +391,26 @@ async def agent_loop( chat_history = copy.deepcopy(prompt) # init() returns the first prompt to be given to the model, and optional metadata dict - chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) + try: + chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) + except Exception: + await self._close_env_after_exception(env, "env.init") + raise initial_chat_history_length = len(chat_history) - initial_input_ids = self.tokenizer.apply_chat_template( - chat_history, - # If retokenize_chat_history==True, avoid including the generation prompt in both the - # prompt_ids and response_ids due to how `response_encodings["input_ids"]` works. - add_generation_prompt=not retokenize_chat_history, - chat_template=self.custom_chat_template if retokenize_chat_history else None, - tokenize=True, - return_dict=False, - **self.generator_cfg.chat_template_kwargs, - ) + try: + initial_input_ids = self.tokenizer.apply_chat_template( + chat_history, + # If retokenize_chat_history==True, avoid including the generation prompt in both the + # prompt_ids and response_ids due to how `response_encodings["input_ids"]` works. + add_generation_prompt=not retokenize_chat_history, + chat_template=self.custom_chat_template if retokenize_chat_history else None, + tokenize=True, + return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) + except Exception: + await self._close_env_after_exception(env, "initial chat templating") + raise initial_prompt_length = len(initial_input_ids) loss_mask = [] # this excludes the prompt @@ -343,7 +429,7 @@ async def agent_loop( agent_loop_output = StepWiseOutput(step_outputs=[]) if is_step_wise else None - get_logprobs = self.generator_cfg.sampling_params.logprobs is not None + get_logprobs = self._uses_rollout_logprobs(sampling_params) agent_loop_state = AgentLoopState( chat_history=chat_history, input_ids=initial_input_ids, @@ -354,7 +440,6 @@ async def agent_loop( ) while not agent_loop_state.done: - if len(agent_loop_state.input_ids) > max_input_length: stop_reason = "length" break @@ -374,9 +459,15 @@ async def agent_loop( agent_loop_state.rollout_logprobs = None engine_input = InferenceEngineInput( - prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params + prompt_token_ids=[agent_loop_state.input_ids], + session_ids=[session_id], + sampling_params=sampling_params, ) - engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) + try: + engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) + except Exception: + await self._close_env_after_exception(env, "inference generation") + raise output = engine_output["responses"][0] output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] @@ -408,7 +499,11 @@ async def agent_loop( added_eos = True # 2. Environment step - env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) + try: + env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) + except Exception: + await self._close_env_after_exception(env, "env.step") + raise new_obs = env_step_output["observations"] step_reward: float = env_step_output["reward"] agent_loop_state.done = env_step_output["done"] @@ -571,7 +666,10 @@ async def agent_loop( return agent_loop_output def _build_per_token_rewards( - self, per_step_rewards: List[Tuple[float, Optional[int]]], response_ids: List[int], appended_eos_token: bool + self, + per_step_rewards: List[Tuple[float, Optional[int]]], + response_ids: List[int], + appended_eos_token: bool, ) -> Union[float, List[float]]: """ Build reward output from per-step rewards. @@ -794,20 +892,24 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False if self.batched: return await self.generate_batched(prompts, env_classes, env_extras, max_tokens, sampling_params) + get_logprobs = self._uses_rollout_logprobs(sampling_params) + # Async agent loop to generate trajectories in parallel. tasks = [] for i in range(len(prompts)): - tasks.append( - self.agent_loop( - prompts[i], - env_classes[i], - env_extras[i], - max_tokens, - max_input_length, - sampling_params=sampling_params, - trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None, - ) + trajectory_id = trajectory_ids[i] if trajectory_ids is not None else None + rollout = self.agent_loop( + prompts[i], + env_classes[i], + env_extras[i], + max_tokens, + max_input_length, + sampling_params=sampling_params, + trajectory_id=trajectory_id, ) + if self.generator_cfg.skip_failed_rollouts: + rollout = self._safe_rollout(i, env_classes[i], trajectory_id, rollout, get_logprobs) + tasks.append(rollout) all_outputs = await tqdm.gather( *tasks, @@ -850,18 +952,15 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False out_trajectory_ids = None has_vision_features = any(getattr(output, "pixel_values", None) is not None for output in all_outputs) - pixel_values = ( - [getattr(output, "pixel_values", None) for output in all_outputs] if has_vision_features else None - ) - image_grid_thw = ( - [getattr(output, "image_grid_thw", None) for output in all_outputs] if has_vision_features else None - ) - - if sampling_params is not None: - # sampling params will be a dict in the format of the inference engine backend - get_logprobs = sampling_params.get("logprobs", None) is not None - else: - get_logprobs = self.generator_cfg.sampling_params.logprobs is not None + pixel_values = None + image_grid_thw = None + if has_vision_features: + pixel_values = self._normalize_optional_tensor_features( + [getattr(output, "pixel_values", None) for output in all_outputs] + ) + image_grid_thw = self._normalize_optional_tensor_features( + [getattr(output, "image_grid_thw", None) for output in all_outputs] + ) if get_logprobs: if self.generator_cfg.step_wise_trajectories: @@ -880,6 +979,14 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False rollout_expert_indices = None rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes, loss_masks) + if self.generator_cfg.skip_failed_rollouts: + num_rollout_errors = sum(reason == ROLLOUT_ERROR_STOP_REASON for reason in stop_reasons) + rollout_metrics["generate/num_rollout_errors"] = num_rollout_errors + rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(stop_reasons) + if num_rollout_errors == len(stop_reasons): + logger.warning( + "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders." + ) if self.generator_cfg.zero_reward_on_non_stop: # set reward to 0 if the stop reason is not "stop" diff --git a/skyrl/train/generators/skyrl_vlm_generator.py b/skyrl/train/generators/skyrl_vlm_generator.py index 76779aac6e..b220901732 100644 --- a/skyrl/train/generators/skyrl_vlm_generator.py +++ b/skyrl/train/generators/skyrl_vlm_generator.py @@ -43,7 +43,13 @@ def __init__( tokenizer, policy_model_name: Optional[str] = None, ): - super().__init__(generator_cfg, skyrl_gym_cfg, inference_engine_client, tokenizer, policy_model_name) + super().__init__( + generator_cfg, + skyrl_gym_cfg, + inference_engine_client, + tokenizer, + policy_model_name, + ) logger.info("Initialized SkyRLVLMGymGenerator (VLM multi-modal generator)") def _validate_cfg(self, generator_cfg: GeneratorConfig): @@ -59,7 +65,12 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig): async def _render_conversation(self, conversation: ConversationType) -> RenderedConversation: rendered = await self.inference_engine_client.render_chat_completion( - {"json": {"model": self.inference_engine_client.model_name, "messages": conversation}} + { + "json": { + "model": self.inference_engine_client.model_name, + "messages": conversation, + } + } ) return RenderedConversation(prompt_ids=rendered["token_ids"], features=rendered.get("features", None)) @@ -88,19 +99,27 @@ async def agent_loop( ) conversation = copy.deepcopy(prompt) - conversation, _ = await self._run_in_executor_if_available(env.init, conversation) + try: + conversation, _ = await self._run_in_executor_if_available(env.init, conversation) + except Exception: + await self._close_env_after_exception(env, "env.init") + raise # Render initial conversation → prompt_ids # latest_features always points to the most recent render's features # (each render covers the full conversation, so later renders supersede earlier ones) - initial_render = await self._render_conversation(conversation) + try: + initial_render = await self._render_conversation(conversation) + except Exception: + await self._close_env_after_exception(env, "initial conversation rendering") + raise prompt_ids = initial_render["prompt_ids"] latest_features = initial_render["features"] current_sampling_params: dict = ( sampling_params if sampling_params is not None else asdict(self.generator_cfg.sampling_params) ) - get_logprobs = self.generator_cfg.sampling_params.logprobs is not None + get_logprobs = self._uses_rollout_logprobs(sampling_params) stop_strs = current_sampling_params.get("stop", None) # ── Accumulators ─────────────────────────────────────────────── @@ -125,7 +144,11 @@ async def agent_loop( while not done: # 1. Render full conversation for this turn's generation input - rendered_conversation = await self._render_conversation(conversation) + try: + rendered_conversation = await self._render_conversation(conversation) + except Exception: + await self._close_env_after_exception(env, "conversation rendering") + raise input_ids = rendered_conversation["prompt_ids"] latest_features = rendered_conversation["features"] @@ -149,7 +172,11 @@ async def agent_loop( sampling_params=current_sampling_params, mm_features=[latest_features] if latest_features is not None else None, ) - engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) + try: + engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) + except Exception: + await self._close_env_after_exception(env, "inference generation") + raise gen_text = engine_output["responses"][0] gen_ids = engine_output["response_ids"][0] @@ -166,7 +193,11 @@ async def agent_loop( added_eos = True # 3. Environment step - env_step_output = await self._run_in_executor_if_available(env.step, gen_text) + try: + env_step_output = await self._run_in_executor_if_available(env.step, gen_text) + except Exception: + await self._close_env_after_exception(env, "env.step") + raise new_obs = env_step_output["observations"] step_reward: float = env_step_output["reward"] done = env_step_output["done"] diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index ec56abbea9..886bd4e4a6 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -19,6 +19,8 @@ ) from skyrl_gym.metrics import aggregate_for_environment +ROLLOUT_ERROR_STOP_REASON = "rollout_error" + def _validate_template_file_path(file_path: str) -> str: """ @@ -302,6 +304,14 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput], step rollout_metrics[k] = max(values) else: rollout_metrics[k] = sum(values) + + has_rollout_error_metric = any( + "generate/num_rollout_errors" in (go.get("rollout_metrics") or {}) for go in generator_outputs + ) + if result.get("stop_reasons") is not None and has_rollout_error_metric: + num_rollout_errors = sum(reason == ROLLOUT_ERROR_STOP_REASON for reason in result["stop_reasons"]) + rollout_metrics["generate/num_rollout_errors"] = num_rollout_errors + rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(result["stop_reasons"]) result["rollout_metrics"] = rollout_metrics # Validate the generator output using the number of prompts diff --git a/tests/train/generators/test_generator_output_utils.py b/tests/train/generators/test_generator_output_utils.py index 1c2ec2bdba..58db930610 100644 --- a/tests/train/generators/test_generator_output_utils.py +++ b/tests/train/generators/test_generator_output_utils.py @@ -9,6 +9,7 @@ from skyrl.train.generators.base import GeneratorOutput, TrajectoryID from skyrl.train.generators.utils import ( + ROLLOUT_ERROR_STOP_REASON, concatenate_generator_outputs, get_metrics_from_generator_output, merge_stepwise_output, @@ -87,6 +88,38 @@ def test_generator_output_concatenation(): np.testing.assert_allclose(concatenated_output["rollout_metrics"][key], value) +def test_generator_output_concatenation_recomputes_rollout_error_rate(): + generator_output_1: GeneratorOutput = { + "prompt_token_ids": [[1], [2]], + "response_ids": [[1], [2]], + "rewards": [[0.0], [1.0]], + "loss_masks": [[0], [1]], + "stop_reasons": [ROLLOUT_ERROR_STOP_REASON, "stop"], + "rollout_logprobs": None, + "rollout_metrics": { + "generate/num_rollout_errors": 1, + "generate/rollout_error_rate": 0.5, + }, + } + generator_output_2: GeneratorOutput = { + "prompt_token_ids": [[3], [4], [5]], + "response_ids": [[3], [4], [5]], + "rewards": [[0.0], [0.0], [1.0]], + "loss_masks": [[0], [0], [1]], + "stop_reasons": [ROLLOUT_ERROR_STOP_REASON, ROLLOUT_ERROR_STOP_REASON, "stop"], + "rollout_logprobs": None, + "rollout_metrics": { + "generate/num_rollout_errors": 2, + "generate/rollout_error_rate": 2 / 3, + }, + } + + concatenated_output = concatenate_generator_outputs([generator_output_1, generator_output_2]) + + assert concatenated_output["rollout_metrics"]["generate/num_rollout_errors"] == 3 + assert concatenated_output["rollout_metrics"]["generate/rollout_error_rate"] == 3 / 5 + + def test_get_metrics_from_generator_output(): # Per trajectory rewards, where rewards are List[float] generator_output: GeneratorOutput = { diff --git a/tests/train/generators/test_skyrl_gym_generator.py b/tests/train/generators/test_skyrl_gym_generator.py index d8bee68a1a..944db60c2a 100644 --- a/tests/train/generators/test_skyrl_gym_generator.py +++ b/tests/train/generators/test_skyrl_gym_generator.py @@ -2,6 +2,7 @@ uv run --extra dev --isolated pytest tests/train/generators/test_skyrl_gym_generator.py """ +import asyncio from typing import Any, Dict, List from unittest.mock import AsyncMock, MagicMock, patch @@ -12,8 +13,14 @@ ConversationType, GeneratorInput, GeneratorOutput, + TrajectoryID, +) +from skyrl.train.generators.skyrl_gym_generator import ( + ROLLOUT_ERROR_STOP_REASON, + SkyRLGymGenerator, + StepWiseOutput, + TrajectoryOutput, ) -from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput # Mock constants, where 4 is the eos token id @@ -435,6 +442,217 @@ async def test_generate_interface_compliance( assert validate_generator_input(input_batch_with_none), "Input with None env_extras should be valid" +def _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) -> SkyRLGymGenerator: + generator_cfg.batched = False + generator_cfg.max_input_length = 512 + generator_cfg.chat_template.source = "name" + generator_cfg.chat_template.name_or_path = None + return SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_llm, + tokenizer=mock_tokenizer, + ) + + +def _generator_input(num_prompts: int, env_class: str, trajectory_ids=None) -> GeneratorInput: + input_batch: GeneratorInput = { + "prompts": [[{"role": "user", "content": f"prompt {i}"}] for i in range(num_prompts)], + "env_extras": [{"idx": i} for i in range(num_prompts)], + "env_classes": [env_class for _ in range(num_prompts)], + } + if trajectory_ids is not None: + input_batch["trajectory_ids"] = trajectory_ids + return input_batch + + +def _successful_rollout( + response_ids: List[int] | None = None, + rollout_logprobs: List[float] | None = None, +) -> TrajectoryOutput: + response_ids = response_ids or [11, 12] + return TrajectoryOutput( + response_ids=response_ids, + reward=[0.0] * (len(response_ids) - 1) + [1.0], + stop_reason="stop", + loss_mask=[1] * len(response_ids), + prompt_ids=[101, 102], + rollout_logprobs=rollout_logprobs, + env_metrics={"ok": 1.0}, + ) + + +@pytest.mark.asyncio +async def test_generate_raises_rollout_exception_by_default(mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg): + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + + async def failing_agent_loop(*args, **kwargs): + raise RuntimeError("rollout failed") + + generator.agent_loop = failing_agent_loop + + with pytest.raises(RuntimeError, match="rollout failed"): + await generator.generate(_generator_input(1, mock_env_cfg.env_class)) + + +@pytest.mark.asyncio +async def test_generate_skip_failed_rollouts_substitutes_placeholder( + mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg +): + generator_cfg.skip_failed_rollouts = True + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + + async def agent_loop(prompt, *args, **kwargs): + if prompt[0]["content"] == "prompt 1": + raise RuntimeError("rollout failed") + return _successful_rollout() + + generator.agent_loop = agent_loop + + output = await generator.generate(_generator_input(2, mock_env_cfg.env_class)) + + assert output["response_ids"] == [[11, 12], [mock_tokenizer.eos_token_id]] + assert output["rewards"] == [[0.0, 1.0], [0.0]] + assert output["loss_masks"] == [[1, 1], [0]] + assert output["stop_reasons"] == ["stop", ROLLOUT_ERROR_STOP_REASON] + assert output["prompt_token_ids"] == [[101, 102], [mock_tokenizer.eos_token_id]] + assert output["rollout_metrics"]["generate/num_rollout_errors"] == 1 + assert output["rollout_metrics"]["generate/rollout_error_rate"] == 0.5 + assert validate_generator_output(output) + + +@pytest.mark.asyncio +@patch("skyrl_gym.make") +async def test_generate_skip_failed_rollouts_closes_env_after_rollout_failure( + mock_make, mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg +): + generator_cfg.skip_failed_rollouts = True + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + env = MagicMock() + env.init.side_effect = RuntimeError("init failed") + env.close.return_value = None + mock_make.return_value = env + + output = await generator.generate(_generator_input(1, mock_env_cfg.env_class)) + + env.close.assert_called_once() + assert output["response_ids"] == [[mock_tokenizer.eos_token_id]] + assert output["rewards"] == [[0.0]] + assert output["loss_masks"] == [[0]] + assert output["stop_reasons"] == [ROLLOUT_ERROR_STOP_REASON] + assert output["rollout_metrics"]["generate/num_rollout_errors"] == 1 + assert output["rollout_metrics"]["generate/rollout_error_rate"] == 1.0 + + +@pytest.mark.asyncio +async def test_generate_skip_failed_rollouts_includes_placeholder_logprobs( + mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg +): + generator_cfg.skip_failed_rollouts = True + generator_cfg.sampling_params.logprobs = 1 + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + + async def agent_loop(prompt, *args, **kwargs): + if prompt[0]["content"] == "prompt 1": + raise RuntimeError("rollout failed") + return _successful_rollout(rollout_logprobs=[-0.1, -0.2]) + + generator.agent_loop = agent_loop + + output = await generator.generate(_generator_input(2, mock_env_cfg.env_class)) + + assert output["rollout_logprobs"] == [[-0.1, -0.2], [0.0]] + assert output["response_ids"][1] == [mock_tokenizer.eos_token_id] + assert output["loss_masks"][1] == [0] + assert validate_generator_output(output) + + +@pytest.mark.asyncio +async def test_generate_skip_failed_rollouts_step_wise_placeholder( + mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg +): + generator_cfg.skip_failed_rollouts = True + generator_cfg.step_wise_trajectories = True + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + trajectory_ids = [ + TrajectoryID(instance_id="sample-a", repetition_id=0), + TrajectoryID(instance_id="sample-b", repetition_id=0), + ] + + async def agent_loop(prompt, *args, **kwargs): + if prompt[0]["content"] == "prompt 1": + raise RuntimeError("rollout failed") + return StepWiseOutput(step_outputs=[_successful_rollout(response_ids=[21])]) + + generator.agent_loop = agent_loop + + output = await generator.generate(_generator_input(2, mock_env_cfg.env_class, trajectory_ids=trajectory_ids)) + + assert output["response_ids"] == [[21], [mock_tokenizer.eos_token_id]] + assert output["rewards"] == [[1.0], [0.0]] + assert output["loss_masks"] == [[1], [0]] + assert output["stop_reasons"] == ["stop", ROLLOUT_ERROR_STOP_REASON] + assert output["trajectory_ids"] == trajectory_ids + assert output["is_last_step"] == [True, True] + assert validate_generator_output(output) + + +@pytest.mark.asyncio +async def test_generate_skip_failed_rollouts_preserves_cancellation( + mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg +): + generator_cfg.skip_failed_rollouts = True + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + + async def cancelled_agent_loop(*args, **kwargs): + raise asyncio.CancelledError() + + generator.agent_loop = cancelled_agent_loop + + with pytest.raises(asyncio.CancelledError): + await generator.generate(_generator_input(1, mock_env_cfg.env_class)) + + +@pytest.mark.asyncio +async def test_generate_skip_failed_rollouts_all_failed_batch(mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg): + generator_cfg.skip_failed_rollouts = True + generator = _make_test_generator(generator_cfg, mock_env_cfg, mock_llm, mock_tokenizer) + + async def failing_agent_loop(*args, **kwargs): + raise RuntimeError("rollout failed") + + generator.agent_loop = failing_agent_loop + + output = await generator.generate(_generator_input(2, mock_env_cfg.env_class)) + + assert output["response_ids"] == [ + [mock_tokenizer.eos_token_id], + [mock_tokenizer.eos_token_id], + ] + assert output["rewards"] == [[0.0], [0.0]] + assert output["loss_masks"] == [[0], [0]] + assert output["stop_reasons"] == [ + ROLLOUT_ERROR_STOP_REASON, + ROLLOUT_ERROR_STOP_REASON, + ] + assert output["rollout_metrics"]["generate/num_rollout_errors"] == 2 + assert output["rollout_metrics"]["generate/rollout_error_rate"] == 1.0 + assert validate_generator_output(output) + + +def test_skip_failed_rollouts_rejects_batched_mode(mock_tokenizer, mock_llm, generator_cfg, mock_env_cfg): + generator_cfg.skip_failed_rollouts = True + generator_cfg.batched = True + + with pytest.raises(ValueError, match="skip_failed_rollouts=True"): + SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_llm, + tokenizer=mock_tokenizer, + ) + + @pytest.mark.asyncio @pytest.mark.parametrize("turns_to_exceed", [1, 3]) # Test single-turn and multi-turn scenarios @patch("skyrl_gym.make")