diff --git a/src/prime_rl/configs/orchestrator.py b/src/prime_rl/configs/orchestrator.py index f9c126ebac..8b742bef70 100644 --- a/src/prime_rl/configs/orchestrator.py +++ b/src/prime_rl/configs/orchestrator.py @@ -754,6 +754,14 @@ class OrchestratorConfig(BaseConfig): # The advantage configuration advantage: AdvantageConfig | None = DefaultAdvantageConfig() + # Filter zero advantages + filter_zero_advantages: Annotated[ + bool, + Field( + description="Whether to filter out training samples with zero advantage. If True, samples with advantage == 0.0 are not sent to the trainer.", + ), + ] = True + # Rollout filters (monitor by default, enforce optionally) filters: list[FilterConfig] = [GibberishFilterConfig(), RepetitionFilterConfig()] diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index 435a1f4af2..52ac1aabb6 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -95,6 +95,24 @@ def check(self, rollout: vf.RolloutOutput) -> FilterResult: return FilterResult(detected=False) +@dataclass +class ZeroAdvantageFilter: + """Flags rollouts with zero advantage. + + This filter is applied after advantages are computed and checks if the + rollout's advantage field is zero. + """ + + name: str + enforce: bool = True + + def check(self, rollout: vf.RolloutOutput) -> FilterResult: + advantage = rollout.get("advantage") + if advantage is not None and advantage == 0.0: + return FilterResult(detected=True) + return FilterResult(detected=False) + + def setup_filter(config: FilterConfig, vocab_size: int) -> RolloutFilter: """Create a RolloutFilter from a filter config.""" if config.type == "gibberish": diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 5227d0358a..314542acc2 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -564,6 +564,18 @@ def _cleanup_env_processes(): config.advantage, ) + # Store advantages on rollouts for filtering + for rollout, advantage in zip(train_rollouts, advantages): + rollout["advantage"] = advantage + + # Apply zero advantage filter if configured + if config.filter_zero_advantages: + from prime_rl.orchestrator.filters import ZeroAdvantageFilter + + zero_advantage_filter = ZeroAdvantageFilter(name="zero_advantage", enforce=True) + zero_advantage_metrics = apply_filters([zero_advantage_filter], train_rollouts) + filter_metrics.update(zero_advantage_metrics) + # Convert rollouts to training samples parallel_preprocess_start = time.perf_counter()