Conversation
…overy StalenessManager's accepted counter started at 0 while the version was restored to a high value by the recovery path. This caused the capacity formula to yield (max_staleness + recovered_version + 1) * batch_size instead of the intended (max_staleness + 1) * batch_size, allowing a burst of rollout submissions and unbounded staleness growth. Add on_version_recovered() to StalenessManager and call it from rl_trainer after recover completes. The trainer accesses the staleness manager directly via the known concrete type (RolloutController in single-controller mode, workflow_executor in SPMD mode).
fishcrap
left a comment
There was a problem hiding this comment.
Overall the fix is correct and well-motivated. Main suggestion is to use public accessors instead of reaching into private attributes. See inline comments.
| if self.recover_info is not None: | ||
| recovery_version = self.recover_info.last_step_info.global_step + 1 | ||
| if is_single_controller(): | ||
| sm = self.rollout._staleness_manager |
There was a problem hiding this comment.
RolloutController already exposes a public staleness_manager property (rollout_controller.py:1112-1114).
Suggestion:
sm = self.rollout.staleness_managerAccessing _staleness_manager directly creates unnecessary coupling to the internal layout.
| if is_single_controller(): | ||
| sm = self.rollout._staleness_manager | ||
| else: | ||
| sm = self.rollout.workflow_executor._staleness_manager |
There was a problem hiding this comment.
This reaches through two layers of private attributes from the trainer. Consider adding a public staleness_manager property (or an on_version_recovered passthrough) to WorkflowExecutor, similar to how RolloutController already exposes one.
This would keep the SPMD path consistent with the single-controller path and avoid breaking if WorkflowExecutor internals change.
| assert capacity == min(1000, expected_staleness_capacity) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("recovered_version", [5, 10, 50]) |
There was a problem hiding this comment.
nit: Consider adding 0 to the parametrize list — recovery to version 0 should be a no-op, and it is a useful edge case to document.
@pytest.mark.parametrize("recovered_version", [0, 5, 10, 50])Also worth adding a case where running > 0 at recovery time, to verify accepted is set correctly regardless of in-flight rollouts.
| intended (max_staleness + 1) * batch_size, causing a burst of | ||
| submissions and unbounded staleness growth. | ||
| """ | ||
| with self.lock: |
There was a problem hiding this comment.
The logic is correct. Minor note: if running > 0 at recovery time (unlikely but theoretically possible), the resulting capacity would be (max_staleness + 1) * consumer_bs - running, which is still bounded — so this is safe. A brief assertion or comment noting that running is expected to be 0 at recovery time could help future readers.
|
/gemini |
Description
Fix staleness capacity inflation after checkpoint recovery in async RL training.
When recovering from a checkpoint,
StalenessManager'sacceptedcounter started at 0 while the model version was restored to a high value. This caused the capacity formula to yield(max_staleness + recovered_version + 1) * batch_sizeinstead of the intended(max_staleness + 1) * batch_size, allowing a burst of rollout submissions and unbounded staleness growthRelated Issue
N/A
Type of Change
Changes
areal/infra/staleness_manager.py: Addon_version_recovered(version)method that adjusts theacceptedcounter toversion * consumer_batch_size, restoring correct capacity bounds.areal/trainer/rl_trainer.py: Callon_version_recovered()afterrecover_handler.load()completes. Accesses the staleness manager directly via the known concrete type (RolloutController._staleness_managerin single-controller mode,workflow_executor._staleness_managerin SPMD mode).tests/test_staleness_manager.py: Add parametrized testtest_on_version_recoveredvalidating capacity stays bounded regardless of recovered version.Checklist
pre-commit run --all-files)./docs/build_all.sh)main/review-prcommand/create-prAdditional Context
Root cause: In the staleness capacity formula
(max_staleness + current_version + 1) * consumer_bs - sample_cnt, after recovery the version jumps (e.g., 0 → 268) butsample_cnt(accepted + running) remains 0. This yields capacity =(2 + 268 + 1) * 4 = 1084on a single SPMD rank, instead of the intended(2 + 1) * 4 = 12, flooding the rollout queue with requests that become stale before consumption.Need help? Check the Contributing Guide or ask in
GitHub Discussions!