diff --git a/src/prime_rl/configs/orchestrator.py b/src/prime_rl/configs/orchestrator.py
index f9c126ebac..4a93df4e6f 100644
--- a/src/prime_rl/configs/orchestrator.py
+++ b/src/prime_rl/configs/orchestrator.py
@@ -562,10 +562,21 @@ class DefaultAdvantageConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
     type: Literal["default"] = "default"
+    length_shaping: Annotated[
+        Literal["off", "brevity_bonus", "gr3"],
+        Field(
+            description=(
+                "Length shaping strategy for reward adjustment. "
+                "'brevity_bonus': attenuate correct rollouts by L_min/L_i (shortest correct keeps 1). "
+                "'gr3': multiplicatively scale all rollouts by (1 + alpha * L_i / L_mean)^-1. "
+                "'off': no length shaping."
+            )
+        ),
+    ] = "off"
     length_shaping_alpha: Annotated[
-        float | None,
-        Field(description="Penalty coefficient for Group Relative Reward Rescaling (GR³). Recommended value: 0.33"),
-    ] = None
+        float,
+        Field(description="Coefficient for gr3 length shaping. Only used when length_shaping='gr3'."),
+    ] = 0.33
 
 
 class CustomAdvantageConfig(BaseModel):
@@ -996,13 +1007,6 @@ def validate_verification_config(self):
             )
         return self
 
-    @model_validator(mode="after")
-    def validate_length_shaping_requires_online_difficulty_filtering(self):
-        if isinstance(self.advantage, DefaultAdvantageConfig) and self.advantage.length_shaping_alpha is not None:
-            if not self.buffer.online_difficulty_filtering:
-                raise ValueError("Group Relative Reward (GR³) scaling requires online difficulty filtering")
-        return self
-
     @model_validator(mode="after")
     def auto_setup_bench(self):
         if self.bench:
diff --git a/src/prime_rl/orchestrator/advantage.py b/src/prime_rl/orchestrator/advantage.py
index 25f942a044..25997e28d6 100644
--- a/src/prime_rl/orchestrator/advantage.py
+++ b/src/prime_rl/orchestrator/advantage.py
@@ -35,16 +35,28 @@ def my_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
 
 
 def default_advantage_fn(
     inputs: AdvantageInputs,
-    length_shaping_alpha: float | None = None,
+    length_shaping: str = "off",
+    length_shaping_alpha: float = 0.33,
 ) -> AdvantageOutputs:
     """Default GRPO advantage: reward minus per-problem baseline."""
     rewards = inputs.rewards
+    completion_lengths = inputs.completion_lengths.to(dtype=rewards.dtype)
 
-    if length_shaping_alpha is not None:
-        completion_lengths = inputs.completion_lengths.to(dtype=rewards.dtype)
+    if length_shaping == "brevity_bonus":
+        correct_mask = rewards >= 1.0
+
+        # Shortest correct completion per problem (inf where no rollout is correct)
+        lengths_masked = completion_lengths.masked_fill(~correct_mask, float("inf"))
+        min_correct_length = lengths_masked.min(dim=1, keepdim=True).values
+
+        # Correct rollouts: reward * L_min/L_i (shortest correct keeps 1, longer ones attenuated)
+        shaped = rewards * (min_correct_length / completion_lengths)
+        rewards = torch.where(correct_mask, shaped, rewards)
+
+    elif length_shaping == "gr3":
         lengths_normalized = completion_lengths / completion_lengths.mean(dim=1, keepdim=True)
-        length_shaping = (1 + length_shaping_alpha * lengths_normalized) ** -1
-        rewards = rewards * length_shaping
+        rewards = rewards * (1 + length_shaping_alpha * lengths_normalized) ** -1
+
     baseline = rewards.mean(dim=1, keepdim=True)
     return AdvantageOutputs(advantages=rewards - baseline)
@@ -64,6 +76,7 @@ def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
         def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
             return default_advantage_fn(
                 inputs,
+                length_shaping=config.length_shaping,
                 length_shaping_alpha=config.length_shaping_alpha,
             )
 
diff --git a/tests/unit/orchestrator/test_advantage.py b/tests/unit/orchestrator/test_advantage.py
index 01ccc2e60a..1aa6e7eb34 100644
--- a/tests/unit/orchestrator/test_advantage.py
+++ b/tests/unit/orchestrator/test_advantage.py
@@ -18,21 +18,65 @@ def test_default_advantage_fn_simple_mean():
     result = default_advantage_fn(inputs)
 
     assert result.advantages.shape == (2, 3)
-    # Check that mean is subtracted per row
     assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(2), atol=1e-6)
 
 
-def test_default_advantage_fn_gr3_length_shaping():
+def test_length_shaping_only_penalizes_correct_rollouts():
+    """Correct rollouts get attenuated by L_min/L_i; incorrect ones are unchanged."""
+    inputs = AdvantageInputs(
+        rewards=torch.tensor([[1.0, 1.0, 0.0, 1.0]]),
+        completion_lengths=torch.tensor([[10, 30, 20, 20]]),
+    )
+    result = default_advantage_fn(inputs, length_shaping="brevity_bonus")
+
+    # min_correct = 10
+    # shaped: [1*10/10, 1*10/30, 0, 1*10/20] = [1.0, 1/3, 0.0, 0.5]
+    shaped_rewards = torch.tensor([1.0, 1.0 / 3, 0.0, 0.5])
+    expected = shaped_rewards - shaped_rewards.mean()
+
+    assert torch.allclose(result.advantages, expected.unsqueeze(0), atol=1e-6)
+
+
+def test_length_shaping_shortest_correct_keeps_full_reward():
+    """The shortest correct rollout keeps reward=1."""
+    inputs = AdvantageInputs(
+        rewards=torch.tensor([[1.0, 1.0, 1.0]]),
+        completion_lengths=torch.tensor([[10, 20, 40]]),
+    )
+    result = default_advantage_fn(inputs, length_shaping="brevity_bonus")
+
+    # shaped: [1*10/10, 1*10/20, 1*10/40] = [1.0, 0.5, 0.25]
+    shaped_rewards = torch.tensor([1.0, 0.5, 0.25])
+    expected = shaped_rewards - shaped_rewards.mean()
+
+    assert torch.allclose(result.advantages, expected.unsqueeze(0), atol=1e-6)
+
+
+def test_length_shaping_no_correct_rollouts():
+    """When no rollout is correct, length shaping has no effect."""
+    inputs = AdvantageInputs(
+        rewards=torch.tensor([[0.0, 0.0, 0.0]]),
+        completion_lengths=torch.tensor([[10, 20, 15]]),
+    )
+    result_with = default_advantage_fn(inputs, length_shaping="brevity_bonus")
+    result_without = default_advantage_fn(inputs)
+
+    assert torch.allclose(result_with.advantages, result_without.advantages, atol=1e-6)
+
+
+def test_gr3_length_shaping():
+    """GR³: multiplicative shaping on all rollouts relative to mean length."""
     inputs = AdvantageInputs(
         rewards=torch.tensor([[1.0, 0.5, 0.8]]),
         completion_lengths=torch.tensor([[10, 20, 10]]),
     )
+    result = default_advantage_fn(inputs, length_shaping="gr3", length_shaping_alpha=0.33)
 
-    result = default_advantage_fn(inputs, length_shaping_alpha=0.33)
-
+    # mean_length = 40/3, normalized = [10/13.33, 20/13.33, 10/13.33] = [0.75, 1.5, 0.75]
+    # factor = (1 + 0.33 * normalized)^-1
+    # rewards_shaped = rewards * factor, then subtract mean
     expected = torch.tensor([[0.20915856, -0.25799648, 0.04883792]])
     assert torch.allclose(result.advantages, expected, atol=1e-6)
-    assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(1), atol=1e-6)
 
 
 def test_compute_advantages_with_config():
@@ -42,9 +86,7 @@ def test_compute_advantages_with_config():
     result = compute_advantages(rewards, lengths, samples_per_problem=3, advantage_config=DefaultAdvantageConfig())
 
     assert len(result) == 6
-    # First 3 should sum to ~0 (mean subtracted)
     assert abs(sum(result[:3])) < 1e-5
-    # Last 3 should sum to ~0
     assert abs(sum(result[3:])) < 1e-5
 
 
@@ -54,7 +96,6 @@ def test_compute_advantages_without_config():
 
     result = compute_advantages(rewards, lengths, samples_per_problem=3, advantage_config=None)
 
-    # Without config, returns raw rewards
     assert result == rewards
 
 
@@ -72,7 +113,6 @@ def test_setup_advantage_fn_with_custom_config():
 
     result = advantage_fn(inputs)
 
     assert isinstance(result, AdvantageOutputs)
-    # Dummy just multiplies rewards by scale
     assert torch.allclose(result.advantages, torch.tensor([[2.0, 1.0, 1.6]]))
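
For reference, a minimal standalone sketch of the two shaping modes this patch adds, using plain tensors in place of the repo's AdvantageInputs (variable names here are illustrative, not part of the API); the numbers mirror the unit tests above:

import torch

# Rewards and lengths are (num_problems, samples_per_problem), as in the tests.

# brevity_bonus: correct rollouts are attenuated by L_min / L_i.
rewards = torch.tensor([[1.0, 1.0, 0.0, 1.0]])
lengths = torch.tensor([[10.0, 30.0, 20.0, 20.0]])
correct = rewards >= 1.0
# Shortest correct completion per problem (inf where no rollout is correct).
min_correct = lengths.masked_fill(~correct, float("inf")).min(dim=1, keepdim=True).values
shaped = torch.where(correct, rewards * (min_correct / lengths), rewards)
print(shaped)  # tensor([[1.0000, 0.3333, 0.0000, 0.5000]])

# gr3: every rollout is scaled by (1 + alpha * L_i / L_mean)^-1.
rewards = torch.tensor([[1.0, 0.5, 0.8]])
lengths = torch.tensor([[10.0, 20.0, 10.0]])
alpha = 0.33
shaped = rewards * (1 + alpha * lengths / lengths.mean(dim=1, keepdim=True)) ** -1

# In both modes the GRPO advantage is the shaped reward minus its per-problem mean.
advantages = shaped - shaped.mean(dim=1, keepdim=True)
print(advantages)  # tensor([[ 0.2092, -0.2580,  0.0488]]), matches test_gr3_length_shaping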