8 changes: 8 additions & 0 deletions src/prime_rl/configs/orchestrator.py
@@ -754,6 +754,14 @@ class OrchestratorConfig(BaseConfig):
    # The advantage configuration
    advantage: AdvantageConfig | None = DefaultAdvantageConfig()

    # Filter zero advantages
    filter_zero_advantages: Annotated[
        bool,
        Field(
            description="Whether to filter out training samples with zero advantage. If True, samples with advantage == 0.0 are not sent to the trainer.",
        ),
    ] = True

New config field defaults to True, changing existing behavior

Medium Severity

filter_zero_advantages defaults to True, which silently changes training behavior for all existing deployments that don't set this field. Every existing config will now have zero-advantage samples masked out, altering training dynamics. Since the PR title describes this as an "option," defaulting to False would be the safer, non-breaking choice.
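
A minimal sketch of the non-breaking alternative suggested here, using a plain dataclass as a stand-in (the project's `BaseConfig` is not reproduced, and `OrchestratorConfigSketch` is a hypothetical name). With the default set to `False`, a config that never mentions the field keeps sending zero-advantage samples to the trainer, and filtering becomes an explicit opt-in:

```python
from dataclasses import dataclass


@dataclass
class OrchestratorConfigSketch:
    """Hypothetical stand-in for OrchestratorConfig; only the new field is shown."""

    # Opt-in: existing configs that omit this field keep their old behavior.
    filter_zero_advantages: bool = False


# A config written before the field existed is unaffected.
legacy = OrchestratorConfigSketch()
# Filtering has to be requested explicitly.
opted_in = OrchestratorConfigSketch(filter_zero_advantages=True)
```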



Missing CHANGELOG entry for new config field

Low Severity

A new config field filter_zero_advantages was added to src/prime_rl/configs/orchestrator.py but CHANGELOG.md was not updated. Per the project rules, any PR that modifies configuration structures in config files must include a corresponding changelog entry.


Triggered by project rule: BugBot Instructions


    # Rollout filters (monitor by default, enforce optionally)
    filters: list[FilterConfig] = [GibberishFilterConfig(), RepetitionFilterConfig()]

18 changes: 18 additions & 0 deletions src/prime_rl/orchestrator/filters.py
@@ -95,6 +95,24 @@ def check(self, rollout: vf.RolloutOutput) -> FilterResult:
        return FilterResult(detected=False)


@dataclass
class ZeroAdvantageFilter:
    """Flags rollouts with zero advantage.

    This filter is applied after advantages are computed and checks if the
    rollout's advantage field is zero.
    """

    name: str
    enforce: bool = True

    def check(self, rollout: vf.RolloutOutput) -> FilterResult:
        advantage = rollout.get("advantage")
        if advantage is not None and advantage == 0.0:
            return FilterResult(detected=True)
        return FilterResult(detected=False)


def setup_filter(config: FilterConfig, vocab_size: int) -> RolloutFilter:
    """Create a RolloutFilter from a filter config."""
    if config.type == "gibberish":
12 changes: 12 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
@@ -564,6 +564,18 @@ def _cleanup_env_processes():
config.advantage,
)

# Store advantages on rollouts for filtering
for rollout, advantage in zip(train_rollouts, advantages):
    rollout["advantage"] = advantage
Comment on lines +567 to +569

Should we make the compute_advantages function work in place, to avoid doing this in the orchestrator code?
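
One way to read this suggestion, sketched under assumptions: the helper writes each advantage onto its rollout instead of returning a list for the caller to zip. The name `assign_advantages_inplace`, the `reward` key, and the mean-baseline formula are all illustrative; the real `compute_advantages` signature and logic are not shown in this diff.

```python
from statistics import fmean


def assign_advantages_inplace(rollouts: list[dict]) -> None:
    """Hypothetical in-place variant: stores rollout["advantage"] directly.

    Assumes each rollout is a dict with a numeric "reward" key and uses
    reward minus the batch mean purely as a placeholder formula.
    """
    if not rollouts:
        return
    baseline = fmean(r["reward"] for r in rollouts)
    for rollout in rollouts:
        rollout["advantage"] = rollout["reward"] - baseline


train_rollouts = [{"reward": 1.0}, {"reward": 0.0}, {"reward": 0.5}]
assign_advantages_inplace(train_rollouts)  # no zip loop needed in the caller
```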


# Apply zero advantage filter if configured
if config.filter_zero_advantages:

Let's make the config the same as for the other filters. Then we can call apply_filters once with all the filters in one place, instead of having a branch just for this option. The only difference would be that enforce is true by default.
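
A rough sketch of what that could look like, assuming filter configs carry a `type` discriminator and an `enforce` flag. The class names ending in `Sketch` are hypothetical, and the real `FilterConfig` fields are not shown in this diff; the `enforce=False` default for the existing filter is inferred from the "monitor by default" comment.

```python
from dataclasses import dataclass


@dataclass
class GibberishFilterConfigSketch:
    type: str = "gibberish"
    enforce: bool = False  # monitor only (assumed default)


@dataclass
class ZeroAdvantageFilterConfigSketch:
    type: str = "zero_advantage"
    enforce: bool = True  # the one difference the reviewer calls out


# All filters live in one list, so the orchestrator needs no special-case branch.
filters = [GibberishFilterConfigSketch(), ZeroAdvantageFilterConfigSketch()]
enforced_types = [f.type for f in filters if f.enforce]
```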

    from prime_rl.orchestrator.filters import ZeroAdvantageFilter

import at top


    zero_advantage_filter = ZeroAdvantageFilter(name="zero_advantage", enforce=True)
    zero_advantage_metrics = apply_filters([zero_advantage_filter], train_rollouts)
    filter_metrics.update(zero_advantage_metrics)

Aggregate filter metrics overwritten by second apply_filters call

Medium Severity

The filter_metrics.update(zero_advantage_metrics) call overwrites the filter/total_detected_rate and filter/total_enforced_rate keys that were set by the first apply_filters call (for gibberish/repetition filters). Since apply_filters always emits these two aggregate keys, the second invocation's values silently replace the originals, making the logged totals reflect only zero-advantage detections rather than the combined total across all filters.
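
The overwrite can be reproduced with a small stand-in. `apply_filters_sketch` below is hypothetical and much simpler than the real `apply_filters`: filters are `(name, predicate)` pairs, and the per-filter key names are made up. Only the aggregate key `filter/total_detected_rate` comes from the comment above.

```python
def apply_filters_sketch(filters, rollouts):
    """Hypothetical stand-in: per-filter detection rates plus one aggregate key."""
    metrics = {}
    flagged_any = [False] * len(rollouts)
    for name, predicate in filters:
        hits = [predicate(r) for r in rollouts]
        metrics[f"filter/{name}_detected_rate"] = sum(hits) / len(rollouts)
        flagged_any = [a or h for a, h in zip(flagged_any, hits)]
    metrics["filter/total_detected_rate"] = sum(flagged_any) / len(rollouts)
    return metrics


rollouts = [
    {"text": "aaaa", "advantage": 0.5},
    {"text": "ok", "advantage": 0.0},
    {"text": "ok", "advantage": 0.0},
    {"text": "ok", "advantage": 1.0},
]
repetition = ("repetition", lambda r: r["text"] == "aaaa")
zero_adv = ("zero_advantage", lambda r: r["advantage"] == 0.0)

# Two calls merged with update(): the second aggregate replaces the first.
two_calls = apply_filters_sketch([repetition], rollouts)
two_calls.update(apply_filters_sketch([zero_adv], rollouts))

# One call with every filter: the aggregate covers all detections.
one_call = apply_filters_sketch([repetition, zero_adv], rollouts)
```

In this toy batch the merged dict reports only the zero-advantage share as the total, while the single call counts the repetition hit as well.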



# Convert rollouts to training samples
parallel_preprocess_start = time.perf_counter()
