Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
parameterized
pytest
pre-commit
py-spy
Expand Down
436 changes: 436 additions & 0 deletions tests/trainer/ppo/test_filter_zero_adv_on_cpu.py

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions tests/trainer/ppo/test_metric_utils_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import torch
from parameterized import parameterized

from verl.trainer.ppo.metric_utils import (
bootstrap_metric,
Expand Down Expand Up @@ -542,5 +543,44 @@ def test_process_validation_metrics_with_pred(self):
# depending on the random sampling, so we don't check the exact value


class TestZeroAdvMetrics(unittest.TestCase):
    """Tests for the zero-advantage metrics emitted by compute_data_metrics."""

    def _make_batch(self, advantages):
        """Build a mock batch object whose only meaningful tensor is `advantages`.

        All other tensors are filler of compatible shape so compute_data_metrics
        can run with use_critic=False.
        """
        num_rows, resp_len = advantages.shape
        mock_batch = MagicMock()
        mock_batch.batch = {
            "advantages": advantages,
            # attention_mask spans prompt + response; prompt length mirrors response length here.
            "attention_mask": torch.ones((num_rows, resp_len * 2)),
            "response_mask": torch.ones((num_rows, resp_len)),
            "responses": torch.zeros((num_rows, resp_len)),
            "returns": torch.ones((num_rows, resp_len)),
            "token_level_rewards": torch.ones((num_rows, resp_len)),
            "token_level_scores": torch.ones((num_rows, resp_len)),
        }
        return mock_batch

    @parameterized.expand(
        [
            # (name, advantages, expected_count, expected_ratio)
            ("all_nonzero_2", ((0.1, 0.2), (0.3, 0.4)), 0, 0.0),
            ("some_zero_2", ((0.0, 0.0), (0.3, 0.4)), 1, 0.5),
            ("all_zero_2", ((0.0, 0.0), (0.0, 0.0)), 2, 1.0),
            ("all_zero_1", ((0.0, 0.0),), 1, 1.0),
            ("some_zero_3", ((0.0, 0.0), (0.3, 0.4), (0.5, 0.6)), 1, 1.0 / 3),
            ("some_zero_4", ((0.0, 0.0), (0.0, 0.0), (0.3, 0.4), (0.5, 0.6)), 2, 0.5),
            ("below_eps", ((1e-9, 1e-10), (0.3, 0.4)), 1, 0.5),
            ("at_eps", ((1e-8, 0.0), (0.3, 0.4)), 0, 0.0),
            ("above_eps", ((1e-7, 0.0), (0.3, 0.4)), 0, 0.0),
        ]
    )
    def test_zero_adv_count_and_ratio(self, _name, advantages, expected_count, expected_ratio):
        """A row counts as zero-adv iff its max |advantage| is strictly below the 1e-8 epsilon."""
        metrics = compute_data_metrics(self._make_batch(torch.tensor(advantages)), use_critic=False)

        self.assertEqual(metrics["critic/advantages/zero_adv_count"], expected_count)
        self.assertAlmostEqual(metrics["critic/advantages/zero_adv_ratio"], expected_ratio)


if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ algorithm:
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
filter_zero_adv:
enable: false
match_loss_curve: true
use_pf_ppo: false
pf_ppo:
reweight_method: pow
Expand Down
3 changes: 3 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ algorithm:
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
filter_zero_adv:
enable: false
match_loss_curve: true
use_pf_ppo: false
pf_ppo:
reweight_method: pow
Expand Down
3 changes: 3 additions & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,9 @@ algorithm:
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
filter_zero_adv:
enable: false
match_loss_curve: true
use_pf_ppo: false
pf_ppo:
reweight_method: pow
Expand Down
21 changes: 20 additions & 1 deletion verl/trainer/config/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from verl.base_config import BaseConfig

__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig", "RolloutCorrectionConfig"]
__all__ = ["AlgoConfig", "FilterGroupsConfig", "FilterZeroAdvConfig", "KLControlConfig", "RolloutCorrectionConfig"]


@dataclass
Expand Down Expand Up @@ -56,6 +56,22 @@ class FilterGroupsConfig(BaseConfig):
max_num_gen_batches: int = 0


@dataclass
class FilterZeroAdvConfig(BaseConfig):
    """Settings for filter_zero_adv: dropping zero-advantage responses before the actor update.

    Args:
        enable (bool): Turn the filter on. Responses from groups where every sample received
            the same reward carry zero advantage and therefore no policy-gradient signal;
            dropping them avoids wasted forward/backward compute.
        match_loss_curve (bool): Insert ghost optimizer.step() calls so the optimizer and
            LR scheduler advance exactly as often as in unfiltered training, keeping the
            convergence curve comparable to the baseline.
    """

    enable: bool = False
    match_loss_curve: bool = True


@dataclass
class RolloutCorrectionConfig(BaseConfig):
"""Configuration for Rollout Correction (addresses off-policy issues in RL training).
Expand Down Expand Up @@ -630,6 +646,8 @@ class AlgoConfig(BaseConfig):
use_pf_ppo (bool): Whether to enable preference feedback PPO.
pf_ppo (dict[str, Any]): Preference feedback PPO settings.
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
filter_zero_adv (FilterZeroAdvConfig): Configuration for skipping zero-advantage responses
in actor update. See FilterZeroAdvConfig for details.
rollout_correction (Optional[RolloutCorrectionConfig]): Rollout Correction configuration.
Addresses off-policy issues from policy mismatch, model staleness, and general distribution shifts.

Expand Down Expand Up @@ -658,6 +676,7 @@ class AlgoConfig(BaseConfig):
use_pf_ppo: bool = False
pf_ppo: dict[str, Any] = field(default_factory=dict)
filter_groups: Optional[FilterGroupsConfig] = None
filter_zero_adv: FilterZeroAdvConfig = field(default_factory=FilterZeroAdvConfig)
# Rollout Correction: corrects off-policy issues (policy mismatch, model staleness, distribution shifts)
# Set to None to disable, use RolloutCorrectionConfig presets (e.g., .tis(), .mis()), or pass dict
rollout_correction: Optional[RolloutCorrectionConfig] = None
Expand Down
10 changes: 10 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ algorithm:
# Target KL divergence (used for adaptive controller)
target_kl: 0.1

# Skip zero-advantage responses in actor update to save compute.
# Responses in all-same-reward groups contribute no policy gradient.
filter_zero_adv:

# Whether to enable filtering
enable: False

# Whether to add ghost optimizer.step() to match baseline convergence curve
match_loss_curve: True

# Whether to enable preference feedback PPO
use_pf_ppo: False

Expand Down
13 changes: 9 additions & 4 deletions verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
tuple[torch.Tensor, dict[str, Any]],
]

LOSS_AGG_SEQ_MEAN_TOKEN_MEAN = "seq-mean-token-mean"
LOSS_AGG_SEQ_MEAN_TOKEN_SUM = "seq-mean-token-sum"
LOSS_AGG_SEQ_MEAN_TOKEN_SUM_NORM = "seq-mean-token-sum-norm"
LOSS_AGG_TOKEN_MEAN = "token-mean"

POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}


Expand Down Expand Up @@ -1165,26 +1170,26 @@ def agg_loss(
loss: `a scalar torch.Tensor`
aggregated loss
"""
if loss_agg_mode == "token-mean":
if loss_agg_mode == LOSS_AGG_TOKEN_MEAN:
if batch_num_tokens is None:
if dp_size > 1:
raise ValueError("(global) batch_num_tokens is required when dp_size > 1")
batch_num_tokens = loss_mask.sum()
loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
elif loss_agg_mode in ["seq-mean-token-sum", "seq-mean-token-sum-norm"]:
elif loss_agg_mode in (LOSS_AGG_SEQ_MEAN_TOKEN_SUM, LOSS_AGG_SEQ_MEAN_TOKEN_SUM_NORM):
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
if global_batch_size is None:
if dp_size > 1:
raise ValueError("global_batch_size is required when dp_size > 1")
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
if loss_agg_mode == "seq-mean-token-sum-norm":
if loss_agg_mode == LOSS_AGG_SEQ_MEAN_TOKEN_SUM_NORM:
if loss_scale_factor is None:
horizon = loss_mask.shape[-1]
loss_scale_factor = horizon
loss /= loss_scale_factor
elif loss_agg_mode == "seq-mean-token-mean":
elif loss_agg_mode == LOSS_AGG_SEQ_MEAN_TOKEN_MEAN:
seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean
seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
Expand Down
141 changes: 138 additions & 3 deletions verl/trainer/ppo/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,135 @@
from verl import DataProto
from verl.utils.import_utils import deprecated

KEY_ADVANTAGES = "advantages"
KEY_ATTENTION_MASK = "attention_mask"
KEY_FILTER_ZERO_ADV_CONFIG = "filter_zero_adv_config"
KEY_NUM_SEQS_CORRECTION_FACTOR = "batch_num_seqs_correction_factor"
KEY_NUM_TOKENS_CORRECTION_FACTOR = "batch_num_tokens_correction_factor"
KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP = "original_batch_size_per_dp_group"
KEY_RESPONSE_MASK = "response_mask"


ZERO_ADV_EPS = 1e-8


def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division: the smallest integer >= a / b."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)


def maybe_add_corrected_mfu(metrics: dict, meta_info: dict) -> None:
    """Attach perf/mfu/actor_corrected when a token-correction factor is present.

    With filter_zero_adv active, perf/mfu/actor is inflated: the FLOPS numerator
    still counts the original (unfiltered) tokens while the measured time only
    covers the kept samples. Scaling MFU by (filtered_tokens / original_tokens)
    reflects the tokens actually processed, roughly matching the baseline MFU.
    """
    token_correction = meta_info.get(KEY_NUM_TOKENS_CORRECTION_FACTOR)
    if token_correction is None:
        return
    metrics["perf/mfu/actor_corrected"] = metrics["perf/mfu/actor"] * token_correction


def _select_shortest(batch: DataProto, indices: torch.Tensor, k: int) -> list[int]:
    """Return the k samples among `indices` with the fewest attention-mask tokens."""
    lengths = batch.batch[KEY_ATTENTION_MASK][indices].sum(dim=-1)
    shortest = lengths.topk(k, largest=False).indices
    return indices[shortest].tolist()


def filter_zero_adv_batch(batch: DataProto, dp_size: int, ppo_mini_batch_size: int = 0) -> tuple[DataProto, dict]:
    """Filter out zero-advantage responses to skip wasted actor compute.

    Responses in all-same-reward groups have advantage≈0 and contribute no policy gradient.
    Pads with shortest zero-adv samples to ensure divisibility by the alignment unit.

    When ppo_mini_batch_size > 0, pads to dp_size * K (K = original mini-batch count)
    so sequences distribute evenly across DP groups and mini-batches. Otherwise
    pads to dp_size only.

    When all samples have zero advantage, keeps alignment-unit shortest samples so the
    optimizer/LR-scheduler still steps (gradients will be ~0).

    Args:
        batch: Full training batch with "advantages", "response_mask", "attention_mask".
        dp_size: Data parallel size for alignment.
        ppo_mini_batch_size: Mini-batch size for K computation. 0 = pad to dp_size only.

    Returns:
        (filtered_batch, metrics): filtered_batch always has ≥ dp_size samples.
    """
    response_mask = batch.batch[KEY_RESPONSE_MASK]
    # Per-response scalar: the largest |advantage| over valid response tokens.
    # A response is "zero-adv" iff this value is below ZERO_ADV_EPS.
    max_abs_adv = (batch.batch[KEY_ADVANTAGES].abs() * response_mask).max(dim=-1).values
    num_total = max_abs_adv.numel()
    bs_per_dp = ceildiv(num_total, dp_size)

    _nonzero_mask = max_abs_adv >= ZERO_ADV_EPS
    nonzero_indices = torch.where(_nonzero_mask)[0].tolist()
    num_nonzero = len(nonzero_indices)

    zero_idx_tensor = torch.where(~_nonzero_mask)[0]
    num_zeros = zero_idx_tensor.numel()

    # Alignment unit: dp_size * K when distributing evenly across mini-batches,
    # otherwise dp_size only. Capped by num_nonzero to ensure each mini-batch
    # gets at least one nonzero sample per DP group.
    if ppo_mini_batch_size > 0:
        k_original = ceildiv(bs_per_dp, ppo_mini_batch_size)
        align_opt_steps = min(k_original, max(1, ceildiv(num_nonzero, dp_size)))
    else:
        align_opt_steps = 1
    align = dp_size * align_opt_steps

    original_num_tokens = response_mask.sum().item()
    if original_num_tokens == 0:
        # Empty batch: skip filtering.
        selected = None
    elif num_nonzero == 0:
        # All zero-adv: keep align shortest for LR schedule continuity (~0 gradient).
        # NOTE(review): _select_shortest uses topk(align); if num_total < align this
        # would raise — presumably callers guarantee num_total >= dp_size. Confirm.
        selected = _select_shortest(batch, zero_idx_tensor, align)
    else:
        # Number of zero-adv fillers needed to round num_nonzero up to a multiple of align.
        num_pad = (-num_nonzero) % align
        if num_zeros <= num_pad:
            # Not enough zero-adv samples to align — skip filtering, use full batch.
            # (Also covers num_zeros == 0 with num_pad == 0, i.e. nothing to filter.)
            selected = None
        elif num_pad > 0:
            selected = nonzero_indices + _select_shortest(batch, zero_idx_tensor, num_pad)
        else:
            selected = nonzero_indices

    if selected is None:
        num_selected = num_total
        filtered_batch = batch
    else:
        num_selected = len(selected)
        # The branches above route every no-op selection to selected=None, so a
        # non-None selection must be a strict subset of the batch.
        assert num_selected != num_total, f"Filtering was a no-op but selected is not None: {num_selected=}"

        filtered_batch = batch[selected]
        filtered_batch.meta_info.update(
            {
                KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP: bs_per_dp,  # per-GPU (matches ppo_mini_batch_size)
                # Loss normalization corrections: agg_loss divides by local token/seq count,
                # but we need to normalize by the original (pre-filter) counts so the
                # gradient magnitude matches the unfiltered baseline.
                KEY_NUM_TOKENS_CORRECTION_FACTOR: (
                    filtered_batch.batch[KEY_RESPONSE_MASK].sum().item() / original_num_tokens
                ),
                KEY_NUM_SEQS_CORRECTION_FACTOR: num_selected / num_total,
            }
        )
    # When selected is None this equals num_total - num_nonzero (kept zero-adv samples).
    num_padded = num_selected - num_nonzero

    metrics = {
        "actor/filter_zero_adv/num_nonzero": num_nonzero,
        "actor/filter_zero_adv/num_padded": num_padded,
        "actor/filter_zero_adv/num_kept": num_selected,
        "actor/filter_zero_adv/num_total": num_total,
        "actor/filter_zero_adv/kept_ratio": num_selected / num_total if num_total > 0 else 0.0,
    }
    return filtered_batch, metrics


@deprecated("verl.utils.metric.reduce_metrics")
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
Expand Down Expand Up @@ -105,13 +234,13 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)

advantages = batch.batch["advantages"]
advantages = batch.batch[KEY_ADVANTAGES]
returns = batch.batch["returns"]

max_response_length = batch.batch["responses"].shape[-1]

prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["response_mask"].bool()
response_mask = batch.batch[KEY_RESPONSE_MASK].bool()

max_prompt_length = prompt_mask.size(-1)

Expand All @@ -136,6 +265,10 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)

# Per-response zero-advantage ratio: responses whose advantage is zero contribute no policy gradient.
max_abs_adv = (advantages.abs() * response_mask).max(dim=-1).values # (bs,)
num_zero_adv = (max_abs_adv < ZERO_ADV_EPS).sum().item()
num_responses = max_abs_adv.numel()
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
Expand Down Expand Up @@ -170,6 +303,8 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
"critic/advantages/zero_adv_count": num_zero_adv,
"critic/advantages/zero_adv_ratio": num_zero_adv / num_responses if num_responses > 0 else 0.0,
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
Expand Down Expand Up @@ -344,7 +479,7 @@ def compute_variance_proxy_metrics(batch: DataProto, gradient_norm: float = None
# Note: IS weight statistics and mismatch metrics are logged in ray_trainer.py

# Get scalar advantages (mean over timesteps)
advantages = batch.batch["advantages"]
advantages = batch.batch[KEY_ADVANTAGES]
# Compute mean advantage per trajectory using masked_mean
advantages_scalar = verl_F.masked_mean(advantages, response_mask, axis=-1)

Expand Down
Loading