def _make_batch(num_nonzero, num_zero, seq_len, attention_lengths=None):
    """Build a synthetic DataProto for the filter_zero_adv_batch tests.

    The first ``num_nonzero`` rows get strictly positive advantages and the
    remaining ``num_zero`` rows get all-zero advantages. Every token counts as
    a response token (``response_mask`` is all ones).

    Args:
        num_nonzero: Number of leading sequences with nonzero advantage.
        num_zero: Number of trailing sequences with zero advantage.
        seq_len: Width of every tensor in the batch.
        attention_lengths: Optional per-sequence attention lengths. Row ``i``
            attends to its first ``attention_lengths[i]`` positions. When
            None, every sequence attends to all ``seq_len`` positions.

    Returns:
        A DataProto wrapping a TensorDict with "advantages", "attention_mask"
        and "response_mask".
    """
    total = num_nonzero + num_zero
    shape = (total, seq_len)

    # |N(0, 1)| + 1 keeps every nonzero-advantage entry at or above 1.0,
    # far away from the zero-advantage threshold.
    advantages = torch.zeros(shape)
    if num_nonzero > 0:
        advantages[:num_nonzero] = torch.randn(num_nonzero, seq_len).abs() + 1.0

    if attention_lengths is None:
        attention_mask = torch.ones(shape)
    else:
        attention_mask = torch.zeros(shape)
        for row, valid_len in enumerate(attention_lengths):
            attention_mask[row, :valid_len] = 1.0

    tensors = {
        "advantages": advantages,
        "attention_mask": attention_mask,
        "response_mask": torch.ones(shape),
    }
    return DataProto(batch=TensorDict(tensors, batch_size=(total,)))
expected_corrected): + """Corrected MFU = original MFU * token_correction_factor; other keys unchanged.""" + metrics = {"perf/mfu/actor": mfu, "perf/mfu/critic": 0.10, "loss": 1.5} + meta_info = {KEY_NUM_TOKENS_CORRECTION_FACTOR: token_correction} + maybe_add_corrected_mfu(metrics, meta_info) + self.assertEqual( + sorted(metrics.keys()), ["loss", "perf/mfu/actor", "perf/mfu/actor_corrected", "perf/mfu/critic"] + ) + self.assertAlmostEqual(metrics["perf/mfu/actor_corrected"], expected_corrected, places=6) + self.assertAlmostEqual(metrics["perf/mfu/actor"], mfu) + self.assertAlmostEqual(metrics["perf/mfu/critic"], 0.10) + self.assertAlmostEqual(metrics["loss"], 1.5) + + def test_correction_overwrites_existing(self): + """When perf/mfu/actor_corrected already exists, it is overwritten.""" + metrics = {"perf/mfu/actor": 0.30, "perf/mfu/actor_corrected": 999.0} + meta_info = {KEY_NUM_TOKENS_CORRECTION_FACTOR: 0.5} + maybe_add_corrected_mfu(metrics, meta_info) + self.assertEqual(sorted(metrics.keys()), ["perf/mfu/actor", "perf/mfu/actor_corrected"]) + self.assertAlmostEqual(metrics["perf/mfu/actor_corrected"], 0.15, places=6) + self.assertAlmostEqual(metrics["perf/mfu/actor"], 0.30) + + +class TestFilterZeroAdvBatch(unittest.TestCase): + """Tests for filter_zero_adv_batch in metric_utils.py.""" + + # ------------------------------------------------------------------ # + # No-op: batch returned unchanged (select logic) + # ------------------------------------------------------------------ # + + @parameterized.expand( + ( + ("all_nonzero_dp2", 8, 0, 4, 2), + ("all_nonzero_dp4", 8, 0, 4, 4), + ("all_nonzero_dp8", 16, 0, 8, 8), + ) + ) + def test_filter_no_op_all_nonzero(self, _name, num_nonzero, num_zero, seq_len, dp_size): + """When all sequences have nonzero advantage, batch is returned unchanged.""" + batch = _make_batch(num_nonzero, num_zero, seq_len) + filtered, metrics = filter_zero_adv_batch(batch, dp_size) + + bs = num_nonzero + num_zero + self.assertIs(filtered, 
batch) + self.assertEqual(sorted(metrics.keys()), list(EXPECTED_METRIC_KEYS)) + self.assertEqual(metrics["actor/filter_zero_adv/num_total"], bs) + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], bs) + self.assertAlmostEqual(metrics["actor/filter_zero_adv/kept_ratio"], 1.0) + self.assertNotIn(KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP, filtered.meta_info) + + @parameterized.expand( + ( + ("need_3_pad_have_1", 9, 1, 4, 4), + ("need_3_pad_have_2", 13, 2, 4, 8), + ("need_3_pad_have_3", 13, 3, 4, 8), + ("need_6_pad_have_6", 10, 6, 4, 8), + ("no_zeros_unaligned", 7, 0, 4, 4), + ) + ) + def test_filter_no_op_not_enough_zeros_to_pad(self, _name, num_nonzero, num_zero, seq_len, dp_size): + """When there aren't enough zero-adv samples for dp_size alignment, skip filtering.""" + batch = _make_batch(num_nonzero, num_zero, seq_len) + filtered, metrics = filter_zero_adv_batch(batch, dp_size) + + bs = num_nonzero + num_zero + self.assertIs(filtered, batch) + self.assertEqual(sorted(metrics.keys()), list(EXPECTED_METRIC_KEYS)) + self.assertEqual(metrics["actor/filter_zero_adv/num_total"], bs) + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], bs) + self.assertAlmostEqual(metrics["actor/filter_zero_adv/kept_ratio"], 1.0) + self.assertNotIn(KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP, filtered.meta_info) + + def test_filter_no_op_empty_response_mask(self): + """When original_num_tokens == 0, skip filtering.""" + bs, seq_len, dp_size = 8, 4, 2 + td = TensorDict( + { + "advantages": torch.zeros(bs, seq_len), + "attention_mask": torch.ones(bs, seq_len), + "response_mask": torch.zeros(bs, seq_len), + }, + batch_size=(bs,), + ) + batch = DataProto(batch=td) + filtered, metrics = filter_zero_adv_batch(batch, dp_size) + + self.assertIs(filtered, batch) + self.assertEqual(sorted(metrics.keys()), list(EXPECTED_METRIC_KEYS)) + self.assertEqual(metrics["actor/filter_zero_adv/num_total"], bs) + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], bs) + 
self.assertAlmostEqual(metrics["actor/filter_zero_adv/kept_ratio"], 1.0) + self.assertNotIn(KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP, filtered.meta_info) + + # ------------------------------------------------------------------ # + # Filtered: sequences are actually removed + # ------------------------------------------------------------------ # + + @parameterized.expand( + ( + # (name, num_nz, num_z, seq, dp, kept, padded, orig_bs_per_dp, seq_corr, token_corr) + ("6nz_4z_dp4", 6, 4, 4, 4, 8, 2, 3, 0.8, 0.8), + ("8nz_4z_dp4_aligned", 8, 4, 4, 4, 8, 0, 3, 2 / 3, 2 / 3), + ("3nz_7z_dp1", 3, 7, 4, 1, 3, 0, 10, 0.3, 0.3), + ("5nz_5z_dp2", 5, 5, 4, 2, 6, 1, 5, 0.6, 0.6), + ("4nz_12z_dp4", 4, 12, 8, 4, 4, 0, 4, 0.25, 0.25), + ("1nz_15z_dp4", 1, 15, 4, 4, 4, 3, 4, 0.25, 0.25), + ) + ) + def test_filtered_mixed( + self, + _name, + num_nonzero, + num_zero, + seq_len, + dp_size, + expected_kept, + expected_padded, + expected_original_bs_per_dp, + expected_seq_correction, + expected_token_correction, + ): + """Nonzero-adv sequences are kept, zero-adv removed (with dp_size padding).""" + batch = _make_batch(num_nonzero, num_zero, seq_len) + filtered, metrics = filter_zero_adv_batch(batch, dp_size) + + bs = num_nonzero + num_zero + # Metrics + self.assertEqual(sorted(metrics.keys()), list(EXPECTED_METRIC_KEYS)) + self.assertEqual(metrics["actor/filter_zero_adv/num_total"], bs) + self.assertEqual(metrics["actor/filter_zero_adv/num_nonzero"], num_nonzero) + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], expected_kept) + self.assertEqual(metrics["actor/filter_zero_adv/num_padded"], expected_padded) + self.assertAlmostEqual(metrics["actor/filter_zero_adv/kept_ratio"], expected_kept / bs) + self.assertEqual(filtered.batch["advantages"].shape[0], expected_kept) + # meta_info + self.assertEqual(filtered.meta_info[KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP], expected_original_bs_per_dp) + self.assertAlmostEqual(filtered.meta_info[KEY_NUM_SEQS_CORRECTION_FACTOR], 
expected_seq_correction) + self.assertAlmostEqual( + filtered.meta_info[KEY_NUM_TOKENS_CORRECTION_FACTOR], expected_token_correction, places=6 + ) + + @parameterized.expand( + ( + # (name, bs, seq, dp, kept, orig_bs_per_dp, seq_corr, token_corr) + ("16x8_dp4", 16, 8, 4, 4, 4, 0.25, 0.25), + ("8x4_dp2", 8, 4, 2, 2, 4, 0.25, 0.25), + ("32x4_dp8", 32, 4, 8, 8, 4, 0.25, 0.25), + ) + ) + def test_filtered_all_zero_keeps_dp_size( + self, + _name, + bs, + seq_len, + dp_size, + expected_kept, + expected_original_bs_per_dp, + expected_seq_correction, + expected_token_correction, + ): + """When all sequences have zero advantage, keep dp_size shortest samples.""" + attention_lengths = tuple(range(1, bs + 1)) + batch = _make_batch(0, bs, seq_len, attention_lengths=attention_lengths) + + filtered, metrics = filter_zero_adv_batch(batch, dp_size) + + # Metrics + self.assertEqual(sorted(metrics.keys()), list(EXPECTED_METRIC_KEYS)) + self.assertEqual(metrics["actor/filter_zero_adv/num_total"], bs) + self.assertEqual(metrics["actor/filter_zero_adv/num_nonzero"], 0) + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], expected_kept) + self.assertAlmostEqual(metrics["actor/filter_zero_adv/kept_ratio"], expected_kept / bs) + self.assertEqual(filtered.batch["advantages"].shape[0], expected_kept) + kept_lengths = filtered.batch["attention_mask"].sum(dim=-1) + self.assertTrue((kept_lengths <= dp_size).all()) + # meta_info + self.assertEqual(filtered.meta_info[KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP], expected_original_bs_per_dp) + self.assertAlmostEqual(filtered.meta_info[KEY_NUM_SEQS_CORRECTION_FACTOR], expected_seq_correction) + self.assertAlmostEqual( + filtered.meta_info[KEY_NUM_TOKENS_CORRECTION_FACTOR], expected_token_correction, places=6 + ) + + # ------------------------------------------------------------------ # + # With ppo_mini_batch_size: align to dp_size * K + # ------------------------------------------------------------------ # + + @parameterized.expand( + ( + # 
(name, num_nz, num_z, seq, dp, mini_bs, kept, padded, orig_bs_per_dp) + ("128_dp8_mini32_k1", 90, 38, 4, 8, 32, 96, 6, 16), + ("128_dp8_mini4_k4", 90, 38, 4, 8, 4, 96, 6, 16), + ("64_dp4_mini8_k2", 40, 24, 4, 4, 8, 40, 0, 16), + ("32_dp4_mini4_k2", 20, 12, 4, 4, 4, 24, 4, 8), + ("32_dp4_mini2_few_nz", 3, 29, 4, 4, 2, 4, 1, 8), + ) + ) + def test_filtered_with_ppo_mini_batch_size( + self, + _name, + num_nonzero, + num_zero, + seq_len, + dp_size, + ppo_mini_batch_size, + expected_kept, + expected_padded, + expected_original_bs_per_dp, + ): + """With ppo_mini_batch_size, align to dp_size * K for even mini-batch distribution.""" + batch = _make_batch(num_nonzero, num_zero, seq_len) + filtered, metrics = filter_zero_adv_batch(batch, dp_size, ppo_mini_batch_size=ppo_mini_batch_size) + + bs = num_nonzero + num_zero + self.assertEqual(metrics["actor/filter_zero_adv/num_total"], bs) + self.assertEqual(metrics["actor/filter_zero_adv/num_nonzero"], num_nonzero) + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], expected_kept) + self.assertEqual(metrics["actor/filter_zero_adv/num_padded"], expected_padded) + self.assertEqual(filtered.batch["advantages"].shape[0], expected_kept) + self.assertEqual(filtered.meta_info[KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP], expected_original_bs_per_dp) + # Verify alignment: kept must be divisible by dp_size * align_opt_steps + bs_per_dp = ceildiv(bs, dp_size) + k_original = ceildiv(bs_per_dp, ppo_mini_batch_size) + align_opt_steps = min(k_original, max(1, ceildiv(num_nonzero, dp_size))) + self.assertEqual(expected_kept % (dp_size * align_opt_steps), 0) + + @parameterized.expand( + ( + # All zero with ppo_mini_batch_size: align to dp_size * K + # 32 total, dp=4, mini_bs=4 → bs_per_dp=8, K=2 + # 0 nz → align_opt_steps=min(2, max(1, ceil(0/4)))=min(2,1)=1, align=4 + ("32_dp4_mini4_all_zero", 32, 4, 4, 4), + # 16 total, dp=2, mini_bs=4 → bs_per_dp=8, K=2 + # 0 nz → align_opt_steps=1, align=2 + ("16_dp2_mini4_all_zero", 16, 4, 2, 4), + ) + ) 
+ def test_filtered_all_zero_with_ppo_mini_batch_size(self, _name, bs, seq_len, dp_size, ppo_mini_batch_size): + """All-zero with ppo_mini_batch_size: align_opt_steps capped at 1, keep dp_size samples.""" + attention_lengths = tuple(range(1, bs + 1)) + batch = _make_batch(0, bs, seq_len, attention_lengths=attention_lengths) + + filtered, metrics = filter_zero_adv_batch(batch, dp_size, ppo_mini_batch_size=ppo_mini_batch_size) + + # align_opt_steps = min(K, max(1, ceil(0/dp))) = 1, so align = dp_size + self.assertEqual(metrics["actor/filter_zero_adv/num_kept"], dp_size) + self.assertEqual(filtered.batch["advantages"].shape[0], dp_size) + + # ------------------------------------------------------------------ # + # Edge cases: eps threshold, response_mask masking + # ------------------------------------------------------------------ # + + @parameterized.expand( + ( + ("half_eps", ZERO_ADV_EPS * 0.5, 0), + ("just_below_eps", ZERO_ADV_EPS * 0.99, 0), + ("at_eps", ZERO_ADV_EPS, 8), # >= check: exactly eps is nonzero + ("double_eps", ZERO_ADV_EPS * 2.0, 8), + ("well_above", 1e-4, 8), + ) + ) + def test_advantage_eps_threshold(self, _name, adv_value, expected_nonzero): + """Advantages are classified as zero/nonzero based on ZERO_ADV_EPS threshold.""" + bs, seq_len, dp_size = 8, 4, 2 + td = TensorDict( + { + "advantages": torch.full((bs, seq_len), adv_value), + "attention_mask": torch.ones(bs, seq_len), + "response_mask": torch.ones(bs, seq_len), + }, + batch_size=(bs,), + ) + batch = DataProto(batch=td) + _, metrics = filter_zero_adv_batch(batch, dp_size) + + self.assertEqual(metrics["actor/filter_zero_adv/num_nonzero"], expected_nonzero) + + def test_response_mask_determines_zero_adv(self): + """Only response tokens matter for zero-adv detection (advantages * response_mask).""" + bs, seq_len, dp_size = 4, 8, 2 + response_mask = torch.zeros(bs, seq_len) + response_mask[2:, 4:] = 1.0 # only last 2 sequences have response tokens + td = TensorDict( + { + "advantages": 
class TestZeroAdvMetrics(unittest.TestCase):
    """Checks the per-response zero-advantage count and ratio from compute_data_metrics."""

    def _make_batch(self, advantages):
        """Wrap ``advantages`` in a MagicMock exposing the ``batch`` tensors the metrics read."""
        num_rows, resp_len = advantages.shape
        resp_shape = (num_rows, resp_len)

        mock_batch = MagicMock()
        mock_batch.batch = {
            "advantages": advantages,
            # Twice the response width: the leading half plays the role of the prompt.
            "attention_mask": torch.ones((num_rows, resp_len * 2)),
            "response_mask": torch.ones(resp_shape),
            "responses": torch.zeros(resp_shape),
            "returns": torch.ones(resp_shape),
            "token_level_rewards": torch.ones(resp_shape),
            "token_level_scores": torch.ones(resp_shape),
        }
        return mock_batch

    @parameterized.expand(
        (
            # (name, advantages, expected_count, expected_ratio)
            ("all_nonzero_2", ((0.1, 0.2), (0.3, 0.4)), 0, 0.0),
            ("some_zero_2", ((0.0, 0.0), (0.3, 0.4)), 1, 0.5),
            ("all_zero_2", ((0.0, 0.0), (0.0, 0.0)), 2, 1.0),
            ("all_zero_1", ((0.0, 0.0),), 1, 1.0),
            ("some_zero_3", ((0.0, 0.0), (0.3, 0.4), (0.5, 0.6)), 1, 1.0 / 3),
            ("some_zero_4", ((0.0, 0.0), (0.0, 0.0), (0.3, 0.4), (0.5, 0.6)), 2, 0.5),
            ("below_eps", ((1e-9, 1e-10), (0.3, 0.4)), 1, 0.5),
            ("at_eps", ((1e-8, 0.0), (0.3, 0.4)), 0, 0.0),
            ("above_eps", ((1e-7, 0.0), (0.3, 0.4)), 0, 0.0),
        )
    )
    def test_zero_adv_count_and_ratio(self, _name, advantages, expected_count, expected_ratio):
        """A response counts as zero-advantage when all of its advantages are below the epsilon."""
        adv_tensor = torch.tensor(advantages)
        metrics = compute_data_metrics(self._make_batch(adv_tensor), use_critic=False)

        observed_count = metrics["critic/advantages/zero_adv_count"]
        observed_ratio = metrics["critic/advantages/zero_adv_ratio"]
        self.assertEqual(observed_count, expected_count)
        self.assertAlmostEqual(observed_ratio, expected_ratio)
@dataclass
class FilterZeroAdvConfig(BaseConfig):
    """Configuration for filter_zero_adv (skip zero-advantage responses in actor update).

    Args:
        enable (bool): Whether to enable filtering. Responses in all-same-reward groups
            contribute no policy gradient; filtering them saves fwd/bwd compute.
        match_loss_curve (bool): Whether to add ghost optimizer.step() calls to preserve
            the same number of optimizer updates as unfiltered training, matching the
            baseline convergence curve.
    """

    # Filtering is opt-in: the default leaves the actor update on the full batch.
    enable: bool = False
    # Defaults to True.
    # NOTE(review): the code that reads this flag is not visible here; confirm it is
    # only consulted when ``enable`` is True.
    match_loss_curve: bool = True
# Names of the loss aggregation modes that ``agg_loss`` dispatches on. Defined once
# so comparisons use constants instead of repeating the string literals.

# Per-sequence token mean (token sum / (token count + 1e-8)); fully masked sequences are excluded.
LOSS_AGG_SEQ_MEAN_TOKEN_MEAN = "seq-mean-token-mean"
# Per-sequence token sum, then a masked sum over sequences divided by the global batch size.
LOSS_AGG_SEQ_MEAN_TOKEN_SUM = "seq-mean-token-sum"
# Same as "seq-mean-token-sum", additionally divided by loss_scale_factor
# (which defaults to the last dimension of the loss mask).
LOSS_AGG_SEQ_MEAN_TOKEN_SUM_NORM = "seq-mean-token-sum-norm"
# Masked sum over all tokens divided by the (global) token count.
LOSS_AGG_TOKEN_MEAN = "token-mean"
# Keys into DataProto.batch (tensors) and DataProto.meta_info (scalars / config).
KEY_ADVANTAGES = "advantages"
KEY_ATTENTION_MASK = "attention_mask"
KEY_FILTER_ZERO_ADV_CONFIG = "filter_zero_adv_config"
KEY_NUM_SEQS_CORRECTION_FACTOR = "batch_num_seqs_correction_factor"
KEY_NUM_TOKENS_CORRECTION_FACTOR = "batch_num_tokens_correction_factor"
KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP = "original_batch_size_per_dp_group"
KEY_RESPONSE_MASK = "response_mask"


# A response whose largest masked |advantage| is below this value counts as zero-advantage.
ZERO_ADV_EPS = 1e-8


def ceildiv(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` using integer floor division only (no float rounding)."""
    return -(-a // b)


def maybe_add_corrected_mfu(metrics: dict, meta_info: dict) -> None:
    """Add corrected MFU metric when filter_zero_adv is active.

    When filter_zero_adv is active, perf/mfu/actor is inflated: the FLOPS
    numerator still reflects the original (unfiltered) token count while
    time is reduced from processing fewer samples. This adds
    perf/mfu/actor_corrected, which scales MFU by
    (filtered_tokens / original_tokens) to match the actual tokens
    processed, roughly matching baseline MFU.

    Args:
        metrics: Reduced actor metrics; updated in place.
        meta_info: meta_info of the batch that was sent to the actor. The
            correction is applied only when it carries a non-None
            ``KEY_NUM_TOKENS_CORRECTION_FACTOR``.

    Nothing is added when the correction factor is absent/None, or when
    ``perf/mfu/actor`` itself was not reported (there is nothing to correct,
    and a missing optional metric must not abort the training step).
    """
    token_correction = meta_info.get(KEY_NUM_TOKENS_CORRECTION_FACTOR)
    raw_mfu = metrics.get("perf/mfu/actor")
    if token_correction is None or raw_mfu is None:
        return
    metrics["perf/mfu/actor_corrected"] = raw_mfu * token_correction


def _select_shortest(batch: "DataProto", indices: torch.Tensor, k: int) -> list[int]:
    """Select the k shortest samples by attention_mask length from the given indices.

    ``k`` must not exceed ``indices.numel()``; ``topk`` raises otherwise.
    """
    seq_lens = batch.batch[KEY_ATTENTION_MASK][indices].sum(dim=-1)
    _, topk_idx = seq_lens.topk(k, largest=False)
    return indices[topk_idx].tolist()


def filter_zero_adv_batch(
    batch: "DataProto", dp_size: int, ppo_mini_batch_size: int = 0
) -> tuple["DataProto", dict]:
    """Filter out zero-advantage responses to skip wasted actor compute.

    Responses in all-same-reward groups have advantage≈0 and contribute no policy gradient.
    Pads with shortest zero-adv samples to ensure divisibility by the alignment unit.

    When ppo_mini_batch_size > 0, pads to dp_size * K (K = original mini-batch count)
    so sequences distribute evenly across DP groups and mini-batches. Otherwise
    pads to dp_size only.

    When all samples have zero advantage, keeps alignment-unit shortest samples so the
    optimizer/LR-scheduler still steps (gradients will be ~0).

    Filtering is skipped (the input batch object is returned unchanged and no
    meta_info keys are written) when:
      * the response mask is empty,
      * there are not enough zero-adv samples to pad up to the alignment unit, or
      * all samples are zero-adv and the batch is not larger than the alignment unit
        (nothing could be dropped).

    Args:
        batch: Full training batch with "advantages", "response_mask", "attention_mask".
        dp_size: Data parallel size for alignment.
        ppo_mini_batch_size: Mini-batch size for K computation. 0 = pad to dp_size only.

    Returns:
        (filtered_batch, metrics). When filtering happens, filtered_batch holds a
        multiple of the alignment unit (at least dp_size) samples and its meta_info
        carries the original per-DP batch size and the token/sequence correction factors.
    """
    response_mask = batch.batch[KEY_RESPONSE_MASK]
    # Largest |advantage| over response tokens, one value per sequence.
    max_abs_adv = (batch.batch[KEY_ADVANTAGES].abs() * response_mask).max(dim=-1).values
    num_total = max_abs_adv.numel()
    bs_per_dp = ceildiv(num_total, dp_size)

    nonzero_mask = max_abs_adv >= ZERO_ADV_EPS
    nonzero_indices = torch.where(nonzero_mask)[0].tolist()
    num_nonzero = len(nonzero_indices)

    zero_idx_tensor = torch.where(~nonzero_mask)[0]
    num_zeros = zero_idx_tensor.numel()

    # Alignment unit: dp_size * K when distributing evenly across mini-batches,
    # otherwise dp_size only. Capped by num_nonzero to ensure each mini-batch
    # gets at least one nonzero sample per DP group.
    if ppo_mini_batch_size > 0:
        k_original = ceildiv(bs_per_dp, ppo_mini_batch_size)
        align_opt_steps = min(k_original, max(1, ceildiv(num_nonzero, dp_size)))
    else:
        align_opt_steps = 1
    align = dp_size * align_opt_steps

    original_num_tokens = response_mask.sum().item()
    if original_num_tokens == 0:
        # Empty batch: skip filtering.
        selected = None
    elif num_nonzero == 0:
        if align < num_zeros:
            # All zero-adv: keep align shortest for LR schedule continuity (~0 gradient).
            selected = _select_shortest(batch, zero_idx_tensor, align)
        else:
            # The batch is no larger than the alignment unit, so nothing can be
            # dropped. topk would raise (align > batch) or select every row.
            selected = None
    else:
        num_pad = (-num_nonzero) % align
        if num_zeros <= num_pad:
            # Not enough zero-adv samples to align — skip filtering, use full batch.
            selected = None
        elif num_pad > 0:
            selected = nonzero_indices + _select_shortest(batch, zero_idx_tensor, num_pad)
        else:
            selected = nonzero_indices

    if selected is None:
        num_selected = num_total
        filtered_batch = batch
    else:
        num_selected = len(selected)
        # Internal invariant: every branch that sets `selected` drops at least one row.
        assert num_selected != num_total, f"Filtering was a no-op but selected is not None: {num_selected=}"

        filtered_batch = batch[selected]
        filtered_batch.meta_info.update(
            {
                KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP: bs_per_dp,  # per-GPU (matches ppo_mini_batch_size)
                # Loss normalization corrections: agg_loss divides by local token/seq count,
                # but we need to normalize by the original (pre-filter) counts so the
                # gradient magnitude matches the unfiltered baseline.
                KEY_NUM_TOKENS_CORRECTION_FACTOR: (
                    filtered_batch.batch[KEY_RESPONSE_MASK].sum().item() / original_num_tokens
                ),
                KEY_NUM_SEQS_CORRECTION_FACTOR: num_selected / num_total,
            }
        )
    # When filtering is skipped this counts every zero-adv row that stayed in the batch.
    num_padded = num_selected - num_nonzero

    metrics = {
        "actor/filter_zero_adv/num_nonzero": num_nonzero,
        "actor/filter_zero_adv/num_padded": num_padded,
        "actor/filter_zero_adv/num_kept": num_selected,
        "actor/filter_zero_adv/num_total": num_total,
        "actor/filter_zero_adv/kept_ratio": num_selected / num_total if num_total > 0 else 0.0,
    }
    return filtered_batch, metrics
batch.batch[KEY_RESPONSE_MASK].bool() max_prompt_length = prompt_mask.size(-1) @@ -136,6 +265,10 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, valid_adv = torch.masked_select(advantages, response_mask) valid_returns = torch.masked_select(returns, response_mask) + # Per-response zero-advantage ratio: responses whose advantage is zero contribute no policy gradient. + max_abs_adv = (advantages.abs() * response_mask).max(dim=-1).values # (bs,) + num_zero_adv = (max_abs_adv < ZERO_ADV_EPS).sum().item() + num_responses = max_abs_adv.numel() if use_critic: values = batch.batch["values"] valid_values = torch.masked_select(values, response_mask) @@ -170,6 +303,8 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, "critic/advantages/mean": torch.mean(valid_adv).detach().item(), "critic/advantages/max": torch.max(valid_adv).detach().item(), "critic/advantages/min": torch.min(valid_adv).detach().item(), + "critic/advantages/zero_adv_count": num_zero_adv, + "critic/advantages/zero_adv_ratio": num_zero_adv / num_responses if num_responses > 0 else 0.0, # returns "critic/returns/mean": torch.mean(valid_returns).detach().item(), "critic/returns/max": torch.max(valid_returns).detach().item(), @@ -344,7 +479,7 @@ def compute_variance_proxy_metrics(batch: DataProto, gradient_norm: float = None # Note: IS weight statistics and mismatch metrics are logged in ray_trainer.py # Get scalar advantages (mean over timesteps) - advantages = batch.batch["advantages"] + advantages = batch.batch[KEY_ADVANTAGES] # Compute mean advantage per trajectory using masked_mean advantages_scalar = verl_F.masked_mean(advantages, response_mask, axis=-1) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 6cb58dc6251..31dd3e338f3 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -44,10 +44,13 @@ from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos 
import AdvantageEstimator, agg_loss from verl.trainer.ppo.metric_utils import ( + KEY_FILTER_ZERO_ADV_CONFIG, compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, compute_variance_proxy_metrics, + filter_zero_adv_batch, + maybe_add_corrected_mfu, process_validation_metrics, ) from verl.trainer.ppo.reward import extract_reward @@ -1554,6 +1557,20 @@ def fit(self): config=self.config.algorithm, ) + # Filter zero-advantage responses to skip wasted actor compute. + # Responses in all-same-reward groups have advantage≈0 and contribute no policy gradient. + # Keep the unfiltered batch for critic update and metrics; use filtered batch for actor. + actor_batch = batch + if self.config.algorithm.filter_zero_adv.enable: + dp_size = self._get_dp_size(self.actor_rollout_wg, "actor") + actor_batch, filter_metrics = filter_zero_adv_batch( + batch, + dp_size, + ppo_mini_batch_size=self.config.actor_rollout_ref.actor.ppo_mini_batch_size, + ) + actor_batch.meta_info[KEY_FILTER_ZERO_ADV_CONFIG] = self.config.algorithm.filter_zero_adv + metrics.update(filter_metrics) + # update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1565,7 +1582,7 @@ def fit(self): if self.config.trainer.critic_warmup <= self.global_steps: # update actor with marked_timer("update_actor", timing_raw, color="red"): - actor_output = self._update_actor(batch) + actor_output = self._update_actor(actor_batch) # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. 
esi_close_to_expiration = should_save_ckpt_esi( @@ -1594,6 +1611,7 @@ def fit(self): self.checkpoint_manager.update_weights(self.global_steps) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + maybe_add_corrected_mfu(actor_output_metrics, actor_batch.meta_info) metrics.update(actor_output_metrics) # Log rollout generations if enabled diff --git a/verl/utils/config.py b/verl/utils/config.py index ccac5f1764f..4bf547807aa 100644 --- a/verl/utils/config.py +++ b/verl/utils/config.py @@ -169,6 +169,24 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: print("NOTICE: You have both enabled in-reward kl and kl loss.") + if config.algorithm.filter_zero_adv.enable and ( + config.actor_rollout_ref.actor.use_kl_loss and config.actor_rollout_ref.actor.kl_loss_coef != 0 + ): + raise ValueError( + "algorithm.filter_zero_adv and actor KL loss (use_kl_loss=True, kl_loss_coef != 0)" + " cannot both be enabled — zero-adv samples still contribute to KL loss." + ) + if config.algorithm.filter_zero_adv.enable and config.actor_rollout_ref.actor.entropy_coeff != 0: + raise ValueError( + "algorithm.filter_zero_adv and actor.entropy_coeff != 0 cannot both be True" + " — zero-adv samples still contribute non-zero entropy gradient." + ) + if config.algorithm.filter_zero_adv.enable and config.actor_rollout_ref.actor.calculate_entropy: + print( + "WARNING: filter_zero_adv with calculate_entropy=True — actor/entropy metric" + " is computed on the filtered batch only, not the full batch." 
@GPUMemoryLogger(role="dp actor", logger=logger)
def _split_filter_zero_adv_mini_batches(
    data: DataProto,
    ppo_mini_batch_size: int,
) -> tuple[list[DataProto], bool, bool, int, dict]:
    """Partition ``data`` into PPO mini-batches, honoring zero-advantage filtering.

    Filtering state is read from ``data.meta_info``:

    * ``KEY_FILTER_ZERO_ADV_CONFIG`` -- the filter config object; its ``enable``
      and ``match_loss_curve`` attributes are consulted (missing attributes
      default to ``False``).
    * ``KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP`` -- the batch size before filtering.
      When this key is absent, filtering is treated as inactive for this batch
      and ``len(data)`` stands in for the original size.

    Args:
        data: The (possibly filtered) batch to split.
        ppo_mini_batch_size: Configured mini-batch size used for the baseline
            step count and for the plain (non-``match_loss_curve``) split.

    Returns:
        ``(mini_batches, filter_zero_adv, match_loss_curve, num_ghost_opt_steps, metrics)``
        where ``mini_batches`` is the result of ``data.split(...)``,
        ``filter_zero_adv`` / ``match_loss_curve`` are the effective flags for
        this batch, ``num_ghost_opt_steps`` is the baseline step count minus the
        number of mini-batches produced (0 when filtering is not enabled in the
        config), and ``metrics`` holds ``actor/num_mini_batches`` plus, when
        filtering is enabled in the config, ``actor/num_ghost_mini_batches``.
    """
    meta = data.meta_info
    cfg = meta.get(KEY_FILTER_ZERO_ADV_CONFIG, None)
    if cfg is None:
        enabled_in_config = False
    else:
        enabled_in_config = getattr(cfg, "enable", False)

    # The original-batch-size key is only written when filtering actually
    # removed something; without it, behave as if filtering were off.
    filter_zero_adv = enabled_in_config and KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP in meta
    match_loss_curve = filter_zero_adv and getattr(cfg, "match_loss_curve", False)

    # Pre-filter per-DP-group batch size; drives the baseline step count K.
    pre_filter_bs = meta.get(KEY_ORIGINAL_BATCH_SIZE_PER_DP_GROUP, len(data))

    split_size = ppo_mini_batch_size
    if match_loss_curve:
        # Keep the baseline number of optimizer steps by spreading the
        # surviving sequences evenly over K mini-batches (at least 1 each).
        target_steps = ceildiv(pre_filter_bs, ppo_mini_batch_size)
        split_size = max(1, ceildiv(len(data), target_steps))
    mini_batches = data.split(split_size)

    metrics = {"actor/num_mini_batches": len(mini_batches)}
    if not enabled_in_config:
        return mini_batches, filter_zero_adv, match_loss_curve, 0, metrics

    # Shortfall in optimizer steps relative to the unfiltered baseline.
    num_ghost_opt_steps = ceildiv(pre_filter_bs, ppo_mini_batch_size) - len(mini_batches)
    metrics["actor/num_ghost_mini_batches"] = num_ghost_opt_steps
    return mini_batches, filter_zero_adv, match_loss_curve, num_ghost_opt_steps, metrics
https://arxiv.org/abs/1707.06347 - mini_batches = data.split(self.config.ppo_mini_batch_size) + mini_batches, _, match_loss_curve, num_ghost_opt_steps, split_metrics = _split_filter_zero_adv_mini_batches( + data, self.config.ppo_mini_batch_size + ) on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1 metrics = { "actor/pg_loss": 0.0, "actor/kl_loss": 0.0, + **split_metrics, } for _ in range(self.config.ppo_epochs): for batch_idx, mini_batch in enumerate(mini_batches): @@ -565,14 +619,21 @@ def update_policy(self, data: DataProto): mini_batch, max_token_len=max_token_len, dp_group=torch.distributed.group.WORLD ) else: - self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - ) micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + if match_loss_curve: + # Use config GA so each sample weighs 1/ppo_mini_batch_size, + # matching baseline per-sample gradient. + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + else: + # Use actual micro-batch count: naturally handles partial last + # mini-batch in the fewer-K path without extra correction. + self.gradient_accumulation = len(micro_batches) self.actor_optimizer.zero_grad() - for micro_batch in micro_batches: + for micro_batch_idx, micro_batch in enumerate(micro_batches): micro_batch = micro_batch.to(get_device_id()) micro_batch_metrics = {} model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id} @@ -585,10 +646,29 @@ def update_policy(self, data: DataProto): calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0) + # Weight each micro-batch so every sequence contributes + # 1/mini_bs to the gradient regardless of micro-batch size. 
if self.config.use_dynamic_bsz: loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size else: loss_scale_factor = 1 / self.gradient_accumulation + # match_loss_curve: each sample weighs 1/ppo_mini_batch_size + # (config) via loss_scale_factor, matching baseline. + # fewer-K: GA = len(micro_batches) naturally handles partial + # last mini-batch without extra correction. + + # Token-density correction for filter_zero_adv with token_mean: + # loss_scale_factor already corrects for sample count, but token_mean + # normalizes by total tokens. Removing zero-adv samples (often long + # all-wrong responses) changes the avg tokens/seq, inflating the + # per-sample gradient. Correct by the token-density ratio + # (tokens_correction / seqs_correction) so the effective denominator + # matches the original batch's token density. + if match_loss_curve and loss_agg_mode == LOSS_AGG_TOKEN_MEAN: + token_corr = data.meta_info.get(KEY_NUM_TOKENS_CORRECTION_FACTOR, 1.0) + seq_corr = data.meta_info.get(KEY_NUM_SEQS_CORRECTION_FACTOR, 1.0) + if seq_corr > 0: + loss_scale_factor *= token_corr / seq_corr # all return: (bsz, response_length) outputs = self._forward_micro_batch( @@ -662,11 +742,8 @@ def update_policy(self, data: DataProto): metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef - if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = policy_loss * loss_scale_factor - else: - loss = policy_loss * loss_scale_factor + loss = policy_loss * loss_scale_factor + if self.scaler is not None: self.scaler.scale(loss).backward() else: @@ -678,5 +755,15 @@ def update_policy(self, data: DataProto): grad_norm = self._optimizer_step() mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, mini_batch_metrics) + + # Ghost optimizer.step() with zero gradients to maintain K + # (the original number of optimizer steps per epoch). 
+ # With match_loss_curve's even distribution, this is typically 0. + if match_loss_curve: + for _ in range(max(0, num_ghost_opt_steps)): + self.actor_optimizer.zero_grad() + grad_norm = self._optimizer_step() + append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()}) + self.actor_optimizer.zero_grad() return metrics