24 changes: 14 additions & 10 deletions src/prime_rl/configs/orchestrator.py
@@ -562,10 +562,21 @@ class DefaultAdvantageConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
     type: Literal["default"] = "default"
+    length_shaping: Annotated[
+        Literal["off", "brevity_bonus", "gr3"],
+        Field(
+            description=(
+                "Length shaping strategy for reward adjustment. "
+                "'brevity_bonus': attenuate correct rollouts by L_min/L_i (shortest correct keeps 1). "
+                "'gr3': multiplicatively scale all rollouts by (1 + alpha * L_i / L_mean)^-1. "
+                "'off': no length shaping."
+            )
+        ),
+    ] = "off"
     length_shaping_alpha: Annotated[
-        float | None,
-        Field(description="Penalty coefficient for Group Relative Reward Rescaling (GR³). Recommended value: 0.33"),
-    ] = None
+        float,
+        Field(description="Coefficient for gr3 length shaping. Only used when length_shaping='gr3'."),
+    ] = 0.33
 
 
 class CustomAdvantageConfig(BaseModel):
@@ -996,13 +1007,6 @@ def validate_verification_config(self):
             )
         return self
 
-    @model_validator(mode="after")
-    def validate_length_shaping_requires_online_difficulty_filtering(self):
-        if isinstance(self.advantage, DefaultAdvantageConfig) and self.advantage.length_shaping_alpha is not None:
-            if not self.buffer.online_difficulty_filtering:
-                raise ValueError("Group Relative Reward (GR³) scaling requires online difficulty filtering")
-        return self
-
     @model_validator(mode="after")
     def auto_setup_bench(self):
         if self.bench:
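For context: with these defaults the new field is opt-in, so existing configs keep today's behavior. A minimal usage sketch (assuming the module path from the file header above; not part of the PR):

from prime_rl.configs.orchestrator import DefaultAdvantageConfig

# Defaults preserve current behavior: shaping disabled, alpha unused.
assert DefaultAdvantageConfig().length_shaping == "off"

# Opt in to GR³ shaping with the default coefficient.
cfg = DefaultAdvantageConfig(length_shaping="gr3", length_shaping_alpha=0.33)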
23 changes: 18 additions & 5 deletions src/prime_rl/orchestrator/advantage.py
@@ -35,16 +35,28 @@ def my_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
 
 def default_advantage_fn(
     inputs: AdvantageInputs,
-    length_shaping_alpha: float | None = None,
+    length_shaping: str = "off",
+    length_shaping_alpha: float = 0.33,
 ) -> AdvantageOutputs:
+    """Default GRPO advantage: reward minus per-problem baseline."""
     rewards = inputs.rewards
+    completion_lengths = inputs.completion_lengths.to(dtype=rewards.dtype)
 
-    if length_shaping_alpha is not None:
-        completion_lengths = inputs.completion_lengths.to(dtype=rewards.dtype)
+    if length_shaping == "brevity_bonus":
+        correct_mask = rewards >= 1.0
+
+        # Shortest correct completion per problem (inf where no rollout is correct)
+        lengths_masked = completion_lengths.masked_fill(~correct_mask, float("inf"))
+        min_correct_length = lengths_masked.min(dim=1, keepdim=True).values
+
+        # Correct rollouts: reward * L_min/L_i (shortest correct keeps 1, longer ones attenuated)
+        shaped = rewards * (min_correct_length / completion_lengths)
+        rewards = torch.where(correct_mask, shaped, rewards)
+
+    elif length_shaping == "gr3":
         lengths_normalized = completion_lengths / completion_lengths.mean(dim=1, keepdim=True)
-        length_shaping = (1 + length_shaping_alpha * lengths_normalized) ** -1
-        rewards = rewards * length_shaping
+        rewards = rewards * (1 + length_shaping_alpha * lengths_normalized) ** -1
 
     baseline = rewards.mean(dim=1, keepdim=True)
 
     return AdvantageOutputs(advantages=rewards - baseline)
@@ -64,6 +76,7 @@ def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
         def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
             return default_advantage_fn(
                 inputs,
+                length_shaping=config.length_shaping,
                 length_shaping_alpha=config.length_shaping_alpha,
             )
 
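For context, here is a standalone sanity check of the two shaping rules in plain torch (a sketch mirroring the diff above; not part of the PR):

import torch

rewards = torch.tensor([[1.0, 1.0, 0.0]])
lengths = torch.tensor([[10.0, 30.0, 20.0]])

# brevity_bonus: correct rollouts are scaled by L_min/L_i, incorrect ones left alone.
correct = rewards >= 1.0
l_min = lengths.masked_fill(~correct, float("inf")).min(dim=1, keepdim=True).values
print(torch.where(correct, rewards * l_min / lengths, rewards))
# tensor([[1.0000, 0.3333, 0.0000]])

# gr3: every rollout is scaled by (1 + alpha * L_i / L_mean)^-1.
alpha = 0.33
print(rewards * (1 + alpha * lengths / lengths.mean(dim=1, keepdim=True)) ** -1)
# tensor([[0.8584, 0.6689, 0.0000]])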
58 changes: 49 additions & 9 deletions tests/unit/orchestrator/test_advantage.py
@@ -18,21 +18,65 @@ def test_default_advantage_fn_simple_mean():
     result = default_advantage_fn(inputs)
 
     assert result.advantages.shape == (2, 3)
-    # Check that mean is subtracted per row
     assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(2), atol=1e-6)
 
 
-def test_default_advantage_fn_gr3_length_shaping():
+def test_length_shaping_only_penalizes_correct_rollouts():
+    """Correct rollouts get attenuated by L_min/L_i; incorrect ones are unchanged."""
+    inputs = AdvantageInputs(
+        rewards=torch.tensor([[1.0, 1.0, 0.0, 1.0]]),
+        completion_lengths=torch.tensor([[10, 30, 20, 20]]),
+    )
+    result = default_advantage_fn(inputs, length_shaping="brevity_bonus")
+
+    # min_correct = 10
+    # shaped: [1*10/10, 1*10/30, 0, 1*10/20] = [1.0, 1/3, 0.0, 0.5]
+    shaped_rewards = torch.tensor([1.0, 1.0 / 3, 0.0, 0.5])
+    expected = shaped_rewards - shaped_rewards.mean()
+
+    assert torch.allclose(result.advantages, expected.unsqueeze(0), atol=1e-6)
+
+
+def test_length_shaping_shortest_correct_keeps_full_reward():
+    """The shortest correct rollout keeps reward=1."""
+    inputs = AdvantageInputs(
+        rewards=torch.tensor([[1.0, 1.0, 1.0]]),
+        completion_lengths=torch.tensor([[10, 20, 40]]),
+    )
+    result = default_advantage_fn(inputs, length_shaping="brevity_bonus")
+
+    # shaped: [1*10/10, 1*10/20, 1*10/40] = [1.0, 0.5, 0.25]
+    shaped_rewards = torch.tensor([1.0, 0.5, 0.25])
+    expected = shaped_rewards - shaped_rewards.mean()
+
+    assert torch.allclose(result.advantages, expected.unsqueeze(0), atol=1e-6)
+
+
+def test_length_shaping_no_correct_rollouts():
+    """When no rollout is correct, length shaping has no effect."""
+    inputs = AdvantageInputs(
+        rewards=torch.tensor([[0.0, 0.0, 0.0]]),
+        completion_lengths=torch.tensor([[10, 20, 15]]),
+    )
+    result_with = default_advantage_fn(inputs, length_shaping="brevity_bonus")
+    result_without = default_advantage_fn(inputs)
+
+    assert torch.allclose(result_with.advantages, result_without.advantages, atol=1e-6)
+
+
+def test_gr3_length_shaping():
+    """GR³: multiplicative shaping on all rollouts relative to mean length."""
     inputs = AdvantageInputs(
         rewards=torch.tensor([[1.0, 0.5, 0.8]]),
         completion_lengths=torch.tensor([[10, 20, 10]]),
     )
-
-    result = default_advantage_fn(inputs, length_shaping_alpha=0.33)
-
+    result = default_advantage_fn(inputs, length_shaping="gr3", length_shaping_alpha=0.33)
+    # mean_length = 40/3, normalized = [10/13.33, 20/13.33, 10/13.33] = [0.75, 1.5, 0.75]
+    # factor = (1 + 0.33 * normalized)^-1
+    # rewards_shaped = rewards * factor, then subtract mean
     expected = torch.tensor([[0.20915856, -0.25799648, 0.04883792]])
     assert torch.allclose(result.advantages, expected, atol=1e-6)
     assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(1), atol=1e-6)
 
 
 def test_compute_advantages_with_config():
@@ -42,9 +86,7 @@ def test_compute_advantages_with_config():
     result = compute_advantages(rewards, lengths, samples_per_problem=3, advantage_config=DefaultAdvantageConfig())
 
     assert len(result) == 6
-    # First 3 should sum to ~0 (mean subtracted)
     assert abs(sum(result[:3])) < 1e-5
-    # Last 3 should sum to ~0
     assert abs(sum(result[3:])) < 1e-5
 
 
@@ -54,7 +96,6 @@ def test_compute_advantages_without_config():
 
     result = compute_advantages(rewards, lengths, samples_per_problem=3, advantage_config=None)
 
-    # Without config, returns raw rewards
     assert result == rewards
 
 
@@ -72,7 +113,6 @@ def test_setup_advantage_fn_with_custom_config():
 
     result = advantage_fn(inputs)
     assert isinstance(result, AdvantageOutputs)
-    # Dummy just multiplies rewards by scale
     assert torch.allclose(result.advantages, torch.tensor([[2.0, 1.0, 1.6]]))
 
 
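The hard-coded expected tensor in test_gr3_length_shaping can be re-derived by hand. A standalone check in plain Python (not part of the PR):

alpha = 0.33
rewards = [1.0, 0.5, 0.8]
lengths = [10, 20, 10]

mean_len = sum(lengths) / len(lengths)  # 40/3
shaped = [r / (1 + alpha * l / mean_len) for r, l in zip(rewards, lengths)]
baseline = sum(shaped) / len(shaped)
print([round(s - baseline, 8) for s in shaped])
# [0.20915856, -0.25799648, 0.04883792] -- matches the expected tensor above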