fix(infra): correct staleness capacity inflation after recovery #1345

Open
daihaowz wants to merge 1 commit into main from fh/bugfix

Conversation

@daihaowz
Collaborator

Description

Fix staleness capacity inflation after checkpoint recovery in async RL training.

When recovering from a checkpoint, StalenessManager's accepted counter started at 0 while the model version was restored to a high value. This caused the capacity formula to yield (max_staleness + recovered_version + 1) * batch_size instead of the intended (max_staleness + 1) * batch_size, allowing a burst of rollout submissions and unbounded staleness growth.

Related Issue

N/A

Type of Change

  • 🐛 Bug fix

Changes

  • areal/infra/staleness_manager.py: Add on_version_recovered(version) method that adjusts the accepted counter to version * consumer_batch_size, restoring correct capacity bounds.
  • areal/trainer/rl_trainer.py: Call on_version_recovered() after recover_handler.load() completes. Accesses the staleness manager directly via the known concrete type (RolloutController._staleness_manager in single-controller mode, workflow_executor._staleness_manager in SPMD mode).
  • tests/test_staleness_manager.py: Add parametrized test test_on_version_recovered validating capacity stays bounded regardless of recovered version.
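
The first change can be pictured with a minimal sketch. `ToyStalenessManager` below is a hypothetical stand-in written for illustration, not the actual class in areal/infra/staleness_manager.py; its attribute names and the `get_capacity` signature are assumptions. Only the rule `accepted = version * consumer_batch_size` comes from the description above.

```python
import threading


class ToyStalenessManager:
    """Hypothetical stand-in; NOT the real areal StalenessManager."""

    def __init__(self, max_staleness, consumer_batch_size):
        self.max_staleness = max_staleness
        self.consumer_batch_size = consumer_batch_size
        self.lock = threading.Lock()
        self.accepted = 0  # rollouts accepted so far
        self.running = 0   # rollouts currently in flight

    def get_capacity(self, current_version):
        # Capacity formula as quoted in this PR (assumed signature).
        with self.lock:
            sample_cnt = self.accepted + self.running
            budget = (self.max_staleness + current_version + 1) * self.consumer_batch_size
            return budget - sample_cnt

    def on_version_recovered(self, version):
        # Pretend every batch before `version` was already accepted, so the
        # remaining budget is (max_staleness + 1) * consumer_batch_size.
        with self.lock:
            self.accepted = version * self.consumer_batch_size
```

In this sketch, calling `on_version_recovered(v)` and then `get_capacity(v)` returns the same bounded value for any `v`, as long as `running` is 0.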

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Additional Context

Root cause: In the staleness capacity formula (max_staleness + current_version + 1) * consumer_bs - sample_cnt, after recovery the version jumps (e.g., 0 → 268) but sample_cnt (accepted + running) remains 0. This yields capacity = (2 + 268 + 1) * 4 = 1084 on a single SPMD rank, instead of the intended (2 + 1) * 4 = 12, flooding the rollout queue with requests that become stale before consumption.
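
The arithmetic in the root-cause note can be checked with a short sketch. The function name and keyword arguments are illustrative assumptions; the formula and the numbers (max_staleness = 2, consumer_bs = 4, recovered version 268) are taken from the paragraph above.

```python
def staleness_capacity(max_staleness, current_version, consumer_bs, accepted, running):
    """(max_staleness + current_version + 1) * consumer_bs - sample_cnt."""
    sample_cnt = accepted + running
    return (max_staleness + current_version + 1) * consumer_bs - sample_cnt


# Before the fix: version restored to 268, counters still at 0.
inflated = staleness_capacity(2, 268, 4, accepted=0, running=0)

# After the fix: accepted is set to version * consumer_bs on recovery.
corrected = staleness_capacity(2, 268, 4, accepted=268 * 4, running=0)

print(inflated, corrected)  # 1084 12
```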



…overy

StalenessManager's accepted counter started at 0 while the version was
restored to a high value by the recovery path. This caused the capacity
formula to yield (max_staleness + recovered_version + 1) * batch_size
instead of the intended (max_staleness + 1) * batch_size, allowing a
burst of rollout submissions and unbounded staleness growth.

Add on_version_recovered() to StalenessManager and call it from
rl_trainer after recover completes. The trainer accesses the staleness
manager directly via the known concrete type (RolloutController in
single-controller mode, workflow_executor in SPMD mode).
Collaborator

@fishcrap fishcrap left a comment

/gemini

@fishcrap fishcrap left a comment

Overall, the fix is correct and well-motivated. The main suggestion is to use public accessors instead of reaching into private attributes. See inline comments.

if self.recover_info is not None:
    recovery_version = self.recover_info.last_step_info.global_step + 1
    if is_single_controller():
        sm = self.rollout._staleness_manager

RolloutController already exposes a public staleness_manager property (rollout_controller.py:1112-1114).

Suggestion:

sm = self.rollout.staleness_manager

Accessing _staleness_manager directly creates unnecessary coupling to the internal layout.

if is_single_controller():
    sm = self.rollout._staleness_manager
else:
    sm = self.rollout.workflow_executor._staleness_manager

This reaches through two layers of private attributes from the trainer. Consider adding a public staleness_manager property (or an on_version_recovered passthrough) to WorkflowExecutor, similar to how RolloutController already exposes one.

This would keep the SPMD path consistent with the single-controller path and avoid breaking if WorkflowExecutor internals change.
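
One possible shape for that suggestion, as a rough sketch only. `ToyWorkflowExecutor` and `resolve_staleness_manager` are hypothetical names invented for illustration, not areal code. The sketch assumes the executor holds the manager in `_staleness_manager`, as the diff above implies.

```python
class ToyWorkflowExecutor:
    """Hypothetical stand-in for WorkflowExecutor; not the areal class."""

    def __init__(self, staleness_manager):
        self._staleness_manager = staleness_manager

    @property
    def staleness_manager(self):
        # Read-only public accessor, so callers never touch the private field.
        return self._staleness_manager


def resolve_staleness_manager(rollout, single_controller):
    """Trainer-side lookup that relies on public attributes only."""
    if single_controller:
        return rollout.staleness_manager
    return rollout.workflow_executor.staleness_manager
```

With an accessor like this, both branches in the trainer read the same public name, and a rename of the private field would only touch the executor.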

assert capacity == min(1000, expected_staleness_capacity)


@pytest.mark.parametrize("recovered_version", [5, 10, 50])

nit: Consider adding 0 to the parametrize list — recovery to version 0 should be a no-op, and it is a useful edge case to document.

@pytest.mark.parametrize("recovered_version", [0, 5, 10, 50])

Also worth adding a case where running > 0 at recovery time, to verify accepted is set correctly regardless of in-flight rollouts.
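
The two extra cases can be sketched in plain Python. The helper below is illustrative and does not use the real test fixtures: it assumes max_staleness = 2 and consumer_bs = 4, and models recovery as `accepted = recovered_version * consumer_bs`. In the actual suite these tuples would presumably become parametrize cases.

```python
def capacity_after_recovery(recovered_version, running, max_staleness=2, consumer_bs=4):
    # Assumed effect of on_version_recovered: accepted = version * consumer_bs.
    accepted = recovered_version * consumer_bs
    sample_cnt = accepted + running
    return (max_staleness + recovered_version + 1) * consumer_bs - sample_cnt


# (recovered_version, running); version 0 leaves accepted at 0, i.e. a no-op.
CASES = [(0, 0), (5, 0), (10, 0), (50, 0), (50, 3)]

for recovered_version, running in CASES:
    expected = (2 + 1) * 4 - running
    assert capacity_after_recovery(recovered_version, running) == expected
```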

intended (max_staleness + 1) * batch_size, causing a burst of
submissions and unbounded staleness growth.
"""
with self.lock:

The logic is correct. Minor note: if running > 0 at recovery time (unlikely but theoretically possible), the resulting capacity would be (max_staleness + 1) * consumer_bs - running, which is still bounded — so this is safe. A brief assertion or comment noting that running is expected to be 0 at recovery time could help future readers.
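
A guard along those lines might look like the sketch below. `GuardedCounters` is a hypothetical class for illustration, and emitting a warning rather than asserting is an assumption made here so that recovery still proceeds when rollouts are in flight.

```python
import threading
import warnings


class GuardedCounters:
    """Hypothetical holder for the accepted/running counters."""

    def __init__(self, consumer_batch_size):
        self.consumer_batch_size = consumer_batch_size
        self.lock = threading.Lock()
        self.accepted = 0
        self.running = 0

    def on_version_recovered(self, version):
        with self.lock:
            # running is expected to be 0 here; capacity stays bounded
            # either way, so this only flags the unexpected state.
            if self.running != 0:
                warnings.warn(
                    f"on_version_recovered called with running={self.running}"
                )
            self.accepted = version * self.consumer_batch_size
```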

@fishcrap

/gemini

3 participants