Commit 705eaec

fix infinite recovery
Summary:
- We don't increase `max_step` when a node is catching up because we don't call `should_commit`. This can leave the node permanently behind and stuck in an infinite recovery loop.
- Note: this can result in the global parameters falling out of sync; the diff includes an RFC on how to fix that if we need to.
- Document another case where `should_commit` can return `True` when it shouldn't, because the allreduce failed (this is only relevant when there can be a pending inflight allreduce).
- Add an assert based on the fragment sync schedule to make sure we don't run into this.

Test Plan:
- Tested on a cluster of 3 nodes by removing and then re-adding a node.
- `max_step` and `local_step` increase in the manager's state dict after both failure and recovery.

Metrics from the healthy node:

<img width="1103" alt="Screenshot 2025-06-15 at 10 53 28 PM copy" src="https://github.com/user-attachments/assets/8640780c-fd20-4266-aa3c-3116776a9c68" />

Metrics from the failed and recovered node:

<img width="1101" alt="Screenshot 2025-06-15 at 10 56 49 PM copy" src="https://github.com/user-attachments/assets/cc2a1c57-715f-4e0a-8e00-7c62da525dc3" />
1 parent 9241a8b commit 705eaec
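To make the failure mode described in the summary concrete, here is a small, self-contained toy model (not torchft code; the names and numbers are made up) of why a replica that never calls `should_commit` while catching up stays behind forever, and why advancing `max_step` during catch-up breaks the loop:

```python
# Toy model of the recovery loop: a replica whose advertised max_step is
# behind the quorum maximum must recover instead of committing. If its
# max_step never advances while it catches up, it is behind again at the
# next quorum, forever.

def run(rounds: int, bump_max_step_on_catch_up: bool) -> int:
    healthy_step = 10   # replica that has been training all along
    lagging_step = 0    # replica that just restarted from an old state
    recoveries = 0

    for _ in range(rounds):
        quorum_max = max(healthy_step, lagging_step)
        lagging_behind = lagging_step < quorum_max

        # The healthy replica commits every round and advances its max_step.
        healthy_step = quorum_max + 1

        if lagging_behind:
            # Behind the quorum: recover (load a checkpoint) instead of
            # committing, so should_commit is never called this round.
            recoveries += 1
            if bump_max_step_on_catch_up:
                # The fix: advance max_step during catch-up anyway.
                lagging_step = quorum_max + 1
        else:
            # Caught up: commit normally.
            lagging_step = quorum_max + 1

    return recoveries


print(run(20, bump_max_step_on_catch_up=False))  # 20 -> recovers every round
print(run(20, bump_max_step_on_catch_up=True))   # 1  -> recovers once, then keeps up
```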

File tree

2 files changed: +61 / -3 lines changed
torchft/local_sgd.py

Lines changed: 60 additions & 2 deletions
@@ -356,14 +356,47 @@ def perform_sync(self) -> bool:
         Overrides the sync method to wait for the scheduled allreduce to finish and
         steps using the outer optimizer.
         """
-        if len(self._allreduce_futures) == 0:
-            return True
+        # Waiting for an allreduce before it has been sent is currently not supported.
+        # Please make sure to not do this to avoid running into inconsistencies.
+        #
+        # This can happen when using large values of `fragment_sync_delay`.
+        # The node might not have participated in syncing of this fragment.
+        #
+        # The allreduce for the other nodes that did participate might
+        # actually succeed, and in that case we shouldn't allow recovery
+        # from this node.
+        #
+        # We do need to increase the `max_step` here so we
+        # don't end up in an infinite loop of needing to recover,
+        # but we can't let other nodes recover from this node
+        # because it doesn't have the latest state.
+        #
+        # We can add an `is_catching_up` flag to the state_dict
+        # to disallow recoveries from this node. Such nodes can
+        # be excluded from the `max_step` calculation unless all
+        # nodes are catching up. This approach makes the replica state
+        # of the global parameters diverge though. So we could add recovery
+        # for a particular fragment from a peer node as a part of
+        # `should_commit` or the next `quorum` when a node is catching up.
+        assert len(self._allreduce_futures) > 0

         self.wait()

         # Restore the parameters back to the previous state
         self.restore_parameters()

+        # For large values of `fragment_sync_delay`, this call can be
+        # a problem.
+        #
+        # It can return success even if the allreduce failed, because
+        # the process group could have been reconfigured while the
+        # allreduce was inflight. The inflight allreduce may or may
+        # not have been aborted.
+        #
+        # We can track errors per allreduce to
+        # let the commit fail here. But this has the downside of
+        # reconfiguring the pg too many times, resulting in
+        # more aborts and more commit failures.
         should_commit = self._manager.should_commit()

         if should_commit:
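The comment on `should_commit` above mentions tracking errors per allreduce as one way to fail the commit when an inflight allreduce was lost to a reconfiguration. A hypothetical sketch of that idea (not part of this diff; the class and method names are invented for illustration):

```python
import torch


class _AllreduceErrorTracker:
    """Remembers whether any tracked allreduce future failed."""

    def __init__(self) -> None:
        self._failed = False

    def track(self, fut: torch.futures.Future) -> torch.futures.Future:
        def _on_done(f: torch.futures.Future) -> None:
            try:
                f.wait()  # re-raises if the allreduce errored or was aborted
            except Exception:
                self._failed = True

        fut.add_done_callback(_on_done)
        return fut

    def ok_to_commit(self) -> bool:
        """Returns False if any tracked allreduce failed, then resets."""
        failed, self._failed = self._failed, False
        return not failed
```

A fragment could wrap each scheduled allreduce with `track()` and gate the commit on `tracker.ok_to_commit() and self._manager.should_commit()`. As the comment notes, the downside is that surfacing these failures forces more process group reconfigurations, which in turn cause more aborts and more failed commits.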
@@ -575,6 +608,13 @@ def __init__(
             for i, model_fragment in enumerate(model_fragments)
         ]

+        # This is to make sure we adhere to the assumptions made by the
+        # `_StreamingDiLoCoFragment` about the fragment sync schedule.
+        assert fragment_sync_delay < sync_every // len(model_fragments)
+
+        # Used to ensure that we try to sync a fragment after we've sent a prepare for it
+        self._first_prepare_sent: set[int] = set()
+
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()

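The assert above encodes a scheduling constraint. Under a simplified model of the fragment schedule (an assumption for illustration only, not the exact torchft logic), where fragment `i` is prepared every `sync_every // num_fragments` steps and synced `fragment_sync_delay` steps after its prepare, the constraint guarantees that a fragment's allreduce is waited on before the next fragment's prepare is issued:

```python
# Simplified schedule model: fragments are prepared at evenly spaced
# offsets within a sync_every window and synced fragment_sync_delay
# steps later.
sync_every = 20
num_fragments = 2
fragment_sync_delay = 5

stride = sync_every // num_fragments   # steps between consecutive prepares
assert fragment_sync_delay < stride    # the constraint added in this commit

for i in range(num_fragments):
    prepare_step = (i + 1) * stride                   # e.g. step 10, step 20
    sync_step = prepare_step + fragment_sync_delay    # e.g. step 15, step 25
    # Because sync_step < prepare_step + stride, fragment i is synced
    # before fragment i + 1 is prepared, so at most one fragment's
    # allreduce is in flight at any time.
    print(f"fragment {i}: prepare at step {prepare_step}, sync at step {sync_step}")
```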
@@ -618,6 +658,8 @@ def _wait(self) -> None:
         for fragment in self._fragments:
             fragment.wait()

+        self._first_prepare_sent.clear()
+
     def _quorum_loop(self) -> None:
         """
         Performs infinite retries until quorum is successfull
@@ -660,12 +702,18 @@ def _step_post_hook(

             logger.debug(f"preparing fragment {i} at step {step}")

+            self._first_prepare_sent.add(i)
             fragment.prepare_sync()

         for i, fragment in enumerate(self._fragments):
             if not fragment.should_sync_fragment(step):
                 continue

+            # We need to have sent an allreduce before we can sync
+            # a fragment
+            if i not in self._first_prepare_sent:
+                continue
+
             logger.debug(f"syncing fragment {i} at step {step}")

             if not fragment.perform_sync():
@@ -708,6 +756,16 @@ def _step_post_hook(
         # waste after recovery
         self._quorum_loop()

+        # TODO: Since we do quorum after commit, there might be a big gap until
+        # the next allreduce. This increases the chances of nodes failing
+        # and so of the allreduce failing.
+        # - We could maybe do a quorum again right before preparing for a fragment
+        #   using `shrink_only`. This might make it tricky for new nodes to join
+        #   though.
+        # - Maintain a sequence number in the state dict that gets bumped at every
+        #   quorum call. Then we can do a quorum right before allreduce and avoid
+        #   doing quorums after commit.
+
         # We need to set make sure `_local_step` is still
         # the same across all replicas if `quorum_id` changed.
         #
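A hypothetical illustration of the sequence-number idea from the TODO above (not part of this diff; the class and names are invented): each quorum bumps a counter that is carried in the state dict, so a replica can tell before scheduling an allreduce whether its membership view is stale and re-run quorum first.

```python
class QuorumClock:
    """Counts quorum reconfigurations so replicas can detect a stale view."""

    def __init__(self) -> None:
        self.seq = 0

    def on_quorum(self) -> None:
        # Bumped every time a quorum completes.
        self.seq += 1

    def state_dict(self) -> dict:
        return {"quorum_seq": self.seq}

    def is_stale(self, latest_seq: int) -> bool:
        # If another replica has seen a newer quorum, our membership view
        # is out of date and we should re-run quorum before the allreduce.
        return self.seq < latest_seq
```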

train_diloco.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def trace_handler(p):
         outer_optimizer,
         backup_device=device,
         sync_every=20 if USE_STREAMING else 20,
-        fragment_sync_delay=10 if USE_STREAMING else 0,
+        fragment_sync_delay=5 if USE_STREAMING else 0,
         should_quantize=False,
     ) as diloco:
         while True:
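Assuming the streaming configuration in this example uses two model fragments, the new assert requires `fragment_sync_delay < sync_every // len(model_fragments)`, i.e. less than `20 // 2 = 10`; the previous value of 10 would now trip the assert, while 5 satisfies it.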
