From 7eedaf95e3085c25fdb96bb4f113db5633b05c92 Mon Sep 17 00:00:00 2001 From: fenghui Date: Sat, 16 May 2026 22:44:19 +0800 Subject: [PATCH 1/2] fix(infra): correct staleness capacity inflation after checkpoint recovery StalenessManager's accepted counter started at 0 while the version was restored to a high value by the recovery path. This caused the capacity formula to yield (max_staleness + recovered_version + 1) * batch_size instead of the intended (max_staleness + 1) * batch_size, allowing a burst of rollout submissions and unbounded staleness growth. Add on_version_recovered() to StalenessManager and call it from rl_trainer after recover completes. The trainer accesses the staleness manager directly via the known concrete type (RolloutController in single-controller mode, workflow_executor in SPMD mode). --- areal/infra/staleness_manager.py | 13 +++++++++++++ areal/trainer/rl_trainer.py | 11 +++++++++++ tests/test_staleness_manager.py | 21 +++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/areal/infra/staleness_manager.py b/areal/infra/staleness_manager.py index 89a782dee2..51c9f72ad4 100644 --- a/areal/infra/staleness_manager.py +++ b/areal/infra/staleness_manager.py @@ -112,6 +112,19 @@ def get_capacity(self) -> int: capacity = min(concurrency_capacity, staleness_capacity) return capacity + def on_version_recovered(self, version: int) -> None: + """Adjust accepted count after checkpoint recovery. + + When a checkpoint is recovered, the version jumps from 0 to the + recovered value. Without adjusting accepted, the capacity formula + yields (max_staleness + version + 1) * batch_size instead of the + intended (max_staleness + 1) * batch_size, causing a burst of + submissions and unbounded staleness growth. + """ + with self.lock: + consumer_bs = max(1, self.consumer_batch_size) + self.rollout_stat.accepted = version * consumer_bs + def on_rollout_enqueued(self) -> None: """Callback when a rollout is enqueued as a pending input task. diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 083e352e27..08e5f7b78a 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -368,6 +368,17 @@ def __init__( weight_update_meta=self.weight_update_meta, ) + # After recovery, sync the staleness manager so its capacity formula + # stays bounded despite the version jumping from 0 to recovery_version. + if self.recover_info is not None: + recovery_version = self.recover_info.last_step_info.global_step + 1 + if is_single_controller(): + sm = self.rollout._staleness_manager + else: + sm = self.rollout.workflow_executor._staleness_manager + if sm is not None: + sm.on_version_recovered(recovery_version) + self._config_perf_tracer() self._apply_initial_offload_policy() diff --git a/tests/test_staleness_manager.py b/tests/test_staleness_manager.py index 67d6a2bb6c..e33b90aa82 100644 --- a/tests/test_staleness_manager.py +++ b/tests/test_staleness_manager.py @@ -766,6 +766,27 @@ def test_parametrized_version_progression(version): assert capacity == min(1000, expected_staleness_capacity) +@pytest.mark.parametrize("recovered_version", [5, 10, 50]) +def test_on_version_recovered(recovered_version): + """Test that on_version_recovered adjusts accepted so capacity stays bounded.""" + version_provider = MockVersionProvider(0) + manager = StalenessManager( + version_provider=version_provider, + max_concurrent_rollouts=1000, + consumer_batch_size=16, + max_staleness=2, + ) + + # Simulate recovery: version jumps to recovered_version + version_provider.set_version(recovered_version) + manager.on_version_recovered(recovered_version) + + # After recovery, capacity should be (max_staleness + 1) * consumer_batch_size + # regardless of the recovered version value. + capacity = manager.get_capacity() + assert capacity == (2 + 1) * 16 + + if __name__ == "__main__": # Run tests with pytest pytest.main([__file__, "-v"]) From 847c080754b3901ddd130142c91e9bde9afb621d Mon Sep 17 00:00:00 2001 From: daihao Date: Tue, 19 May 2026 11:06:44 +0800 Subject: [PATCH 2/2] fix(infra): clarify staleness recovery semantics and use public APIs Address review feedback on the staleness manager recovery path: - Document that on_version_recovered is expected to be called with running == 0 and explain the bound when it is not. - Reach the manager through the public staleness_manager properties on RolloutController and WorkflowExecutor instead of the private _staleness_manager attribute, avoiding coupling to internal layout. - Extend tests with the version=0 no-op case and a parametrized case with in-flight rollouts to verify accepted is set correctly. --- areal/infra/staleness_manager.py | 5 +++++ areal/trainer/rl_trainer.py | 4 ++-- tests/test_staleness_manager.py | 27 ++++++++++++++++++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/areal/infra/staleness_manager.py b/areal/infra/staleness_manager.py index 51c9f72ad4..9be01ab5c4 100644 --- a/areal/infra/staleness_manager.py +++ b/areal/infra/staleness_manager.py @@ -120,6 +120,11 @@ def on_version_recovered(self, version: int) -> None: yields (max_staleness + version + 1) * batch_size instead of the intended (max_staleness + 1) * batch_size, causing a burst of submissions and unbounded staleness growth. + + Expected to be called during trainer init, before any rollouts are + submitted, so running == 0. If running > 0 (unlikely in practice), + accepted is still set correctly and the capacity formula remains + bounded — (max_staleness + 1) * consumer_bs - running. """ with self.lock: consumer_bs = max(1, self.consumer_batch_size) diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 08e5f7b78a..c7249eb281 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -373,9 +373,9 @@ def __init__( if self.recover_info is not None: recovery_version = self.recover_info.last_step_info.global_step + 1 if is_single_controller(): - sm = self.rollout._staleness_manager + sm = self.rollout.staleness_manager else: - sm = self.rollout.workflow_executor._staleness_manager + sm = self.rollout.workflow_executor.staleness_manager if sm is not None: sm.on_version_recovered(recovery_version) diff --git a/tests/test_staleness_manager.py b/tests/test_staleness_manager.py index e33b90aa82..cd8c5ffc33 100644 --- a/tests/test_staleness_manager.py +++ b/tests/test_staleness_manager.py @@ -766,7 +766,7 @@ def test_parametrized_version_progression(version): assert capacity == min(1000, expected_staleness_capacity) -@pytest.mark.parametrize("recovered_version", [5, 10, 50]) +@pytest.mark.parametrize("recovered_version", [0, 5, 10, 50]) def test_on_version_recovered(recovered_version): """Test that on_version_recovered adjusts accepted so capacity stays bounded.""" version_provider = MockVersionProvider(0) @@ -787,6 +787,31 @@ def test_on_version_recovered(recovered_version): assert capacity == (2 + 1) * 16 +@pytest.mark.parametrize("running", [1, 5, 16]) +def test_on_version_recovered_with_running_rollouts(running): + """Test that on_version_recovered sets accepted correctly even when running > 0.""" + recovered_version = 10 + version_provider = MockVersionProvider(0) + manager = StalenessManager( + version_provider=version_provider, + max_concurrent_rollouts=1000, + consumer_batch_size=16, + max_staleness=2, + ) + + # Simulate in-flight rollouts at recovery time + for _ in range(running): + manager.on_rollout_enqueued() + manager.on_rollout_submitted() + + version_provider.set_version(recovered_version) + manager.on_version_recovered(recovered_version) + + # Capacity formula: (max_staleness + 1) * consumer_bs - running + capacity = manager.get_capacity() + assert capacity == (2 + 1) * 16 - running + + if __name__ == "__main__": # Run tests with pytest pytest.main([__file__, "-v"])