diff --git a/areal/infra/staleness_manager.py b/areal/infra/staleness_manager.py
index 89a782dee..9be01ab5c 100644
--- a/areal/infra/staleness_manager.py
+++ b/areal/infra/staleness_manager.py
@@ -112,6 +112,24 @@ def get_capacity(self) -> int:
             capacity = min(concurrency_capacity, staleness_capacity)
             return capacity
 
+    def on_version_recovered(self, version: int) -> None:
+        """Adjust accepted count after checkpoint recovery.
+
+        When a checkpoint is recovered, the version jumps from 0 to the
+        recovered value. Without adjusting accepted, the capacity formula
+        yields (max_staleness + version + 1) * batch_size instead of the
+        intended (max_staleness + 1) * batch_size, causing a burst of
+        submissions and unbounded staleness growth.
+
+        Overwrites accepted with version * max(1, consumer_batch_size) while
+        holding self.lock. Expected to be called during trainer init, before
+        any rollouts are submitted, so running == 0. If running > 0 (unlikely in
+        practice), staleness capacity is (max_staleness + 1) * consumer_bs - running.
+        """
+        with self.lock:
+            consumer_bs = max(1, self.consumer_batch_size)
+            self.rollout_stat.accepted = version * consumer_bs
+
     def on_rollout_enqueued(self) -> None:
         """Callback when a rollout is enqueued as a pending input task.
 
diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py
index 083e352e2..c7249eb28 100644
--- a/areal/trainer/rl_trainer.py
+++ b/areal/trainer/rl_trainer.py
@@ -368,6 +368,17 @@ def __init__(
                 weight_update_meta=self.weight_update_meta,
             )
 
+        # After recovery, sync the staleness manager so its capacity formula
+        # stays bounded despite the version jumping from 0 to recovery_version.
+        if self.recover_info is not None:
+            recovery_version = self.recover_info.last_step_info.global_step + 1
+            if is_single_controller():
+                sm = self.rollout.staleness_manager
+            else:
+                sm = self.rollout.workflow_executor.staleness_manager
+            if sm is not None:
+                sm.on_version_recovered(recovery_version)
+
         self._config_perf_tracer()
         self._apply_initial_offload_policy()
 
diff --git a/tests/test_staleness_manager.py b/tests/test_staleness_manager.py
index 67d6a2bb6..cd8c5ffc3 100644
--- a/tests/test_staleness_manager.py
+++ b/tests/test_staleness_manager.py
@@ -766,6 +766,52 @@ def test_parametrized_version_progression(version):
     assert capacity == min(1000, expected_staleness_capacity)
 
 
+@pytest.mark.parametrize("recovered_version", [0, 5, 10, 50])
+def test_on_version_recovered(recovered_version):
+    """Test that on_version_recovered adjusts accepted so capacity stays bounded."""
+    version_provider = MockVersionProvider(0)
+    manager = StalenessManager(
+        version_provider=version_provider,
+        max_concurrent_rollouts=1000,
+        consumer_batch_size=16,
+        max_staleness=2,
+    )
+
+    # Simulate recovery: version jumps to recovered_version
+    version_provider.set_version(recovered_version)
+    manager.on_version_recovered(recovered_version)
+
+    # After recovery, capacity should be (max_staleness + 1) * consumer_batch_size
+    # regardless of the recovered version value.
+    capacity = manager.get_capacity()
+    assert capacity == (2 + 1) * 16
+
+
+@pytest.mark.parametrize("running", [1, 5, 16])
+def test_on_version_recovered_with_running_rollouts(running):
+    """Test that on_version_recovered sets accepted correctly even when running > 0."""
+    recovered_version = 10
+    version_provider = MockVersionProvider(0)
+    manager = StalenessManager(
+        version_provider=version_provider,
+        max_concurrent_rollouts=1000,
+        consumer_batch_size=16,
+        max_staleness=2,
+    )
+
+    # Simulate in-flight rollouts at recovery time
+    for _ in range(running):
+        manager.on_rollout_enqueued()
+        manager.on_rollout_submitted()
+
+    version_provider.set_version(recovered_version)
+    manager.on_version_recovered(recovered_version)
+
+    # Capacity formula: (max_staleness + 1) * consumer_bs - running
+    capacity = manager.get_capacity()
+    assert capacity == (2 + 1) * 16 - running
+
+
 if __name__ == "__main__":
     # Run tests with pytest
     pytest.main([__file__, "-v"])