Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions areal/infra/staleness_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ def get_capacity(self) -> int:
capacity = min(concurrency_capacity, staleness_capacity)
return capacity

def on_version_recovered(self, version: int) -> None:
"""Adjust accepted count after checkpoint recovery.

When a checkpoint is recovered, the version jumps from 0 to the
recovered value. Without adjusting accepted, the capacity formula
yields (max_staleness + version + 1) * batch_size instead of the
intended (max_staleness + 1) * batch_size, causing a burst of
submissions and unbounded staleness growth.
"""
with self.lock:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic is correct. Minor note: if running > 0 at recovery time (unlikely but theoretically possible), the resulting capacity would be (max_staleness + 1) * consumer_bs - running, which is still bounded — so this is safe. A brief assertion or comment noting that running is expected to be 0 at recovery time could help future readers.

consumer_bs = max(1, self.consumer_batch_size)
self.rollout_stat.accepted = version * consumer_bs

def on_rollout_enqueued(self) -> None:
"""Callback when a rollout is enqueued as a pending input task.

Expand Down
11 changes: 11 additions & 0 deletions areal/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,17 @@ def __init__(
weight_update_meta=self.weight_update_meta,
)

# After recovery, sync the staleness manager so its capacity formula
# stays bounded despite the version jumping from 0 to recovery_version.
if self.recover_info is not None:
recovery_version = self.recover_info.last_step_info.global_step + 1
if is_single_controller():
sm = self.rollout._staleness_manager
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RolloutController already exposes a public staleness_manager property (rollout_controller.py:1112-1114).

Suggestion:

sm = self.rollout.staleness_manager

Accessing _staleness_manager directly creates unnecessary coupling to the internal layout.

else:
sm = self.rollout.workflow_executor._staleness_manager
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reaches through two layers of private attributes from the trainer. Consider adding a public staleness_manager property (or an on_version_recovered passthrough) to WorkflowExecutor, similar to how RolloutController already exposes one.

This would keep the SPMD path consistent with the single-controller path and avoid breaking if WorkflowExecutor internals change.

if sm is not None:
sm.on_version_recovered(recovery_version)

self._config_perf_tracer()
self._apply_initial_offload_policy()

Expand Down
21 changes: 21 additions & 0 deletions tests/test_staleness_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,27 @@ def test_parametrized_version_progression(version):
assert capacity == min(1000, expected_staleness_capacity)


@pytest.mark.parametrize("recovered_version", [5, 10, 50])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Consider adding 0 to the parametrize list — recovery to version 0 should be a no-op, and it is a useful edge case to document.

@pytest.mark.parametrize("recovered_version", [0, 5, 10, 50])

Also worth adding a case where running > 0 at recovery time, to verify accepted is set correctly regardless of in-flight rollouts.

def test_on_version_recovered(recovered_version):
"""Test that on_version_recovered adjusts accepted so capacity stays bounded."""
version_provider = MockVersionProvider(0)
manager = StalenessManager(
version_provider=version_provider,
max_concurrent_rollouts=1000,
consumer_batch_size=16,
max_staleness=2,
)

# Simulate recovery: version jumps to recovered_version
version_provider.set_version(recovered_version)
manager.on_version_recovered(recovered_version)

# After recovery, capacity should be (max_staleness + 1) * consumer_batch_size
# regardless of the recovered version value.
capacity = manager.get_capacity()
assert capacity == (2 + 1) * 16


if __name__ == "__main__":
# Run tests with pytest
pytest.main([__file__, "-v"])
Loading