
Commit fd873eb

add streaming diloco tests - upscale

Summary:
- add a "barrier injector" to allow replicas to join at specific steps
- add a test to validate that streaming diloco works when a new node joins without any failures
- remove some code duplication

Test Plan:

```
pytest -vs ./torchft/local_sgd_integ_test.py
```
1 parent eaa005a commit fd873eb
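
The gating idea described in the summary, sketched with plain threading primitives (the replica count, gate steps, and `replica` stub here are illustrative, not taken from this commit): a shared `threading.Barrier` sized to the full replica count lets the late joiner block at step 0 until the established replicas reach their gate step, so all replicas proceed together from a known point.

```python
import threading

barrier = threading.Barrier(3)  # one party per replica

def replica(replica_id: int, gate_step: int, total_steps: int = 4) -> None:
    for step in range(total_steps):
        if step == gate_step:
            # The late joiner gates at step 0; established replicas gate at
            # step 2, so nobody passes until the others have taken 2 steps.
            barrier.wait()
        print(f"replica {replica_id} running step {step}")  # stand-in for a train step

threads = [
    threading.Thread(target=replica, args=(0, 0)),  # joins late
    threading.Thread(target=replica, args=(1, 2)),
    threading.Thread(target=replica, args=(2, 2)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```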

File tree: 2 files changed (+165, -36 lines)


torchft/local_sgd_integ_test.py (137 additions, 36 deletions)
```diff
@@ -3,23 +3,26 @@
 import os
 import re
 import sys
+import threading
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
+from dataclasses import field
 from datetime import timedelta
 from typing import Any, Dict
 from unittest import TestCase, skipIf
 
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.distributed.tensor import DTensor, Replicate
 
 from torchft._torchft import LighthouseServer
 from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
-from torchft.manager_integ_test import FailureInjector, MyModel, Runner
+from torchft.manager_integ_test import BarrierInjector, FailureInjector, MyModel, Runner
 from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -254,6 +257,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
                 all_state_dicts[manager_curr_step] = copy.deepcopy(
                     manager._manager_state_dict()
                 )
+
+            if runner.barrier_injector is not None:
+                runner.barrier_injector.check(manager_curr_step)
+
             batch_size = 1
             inputs = m.get_rand_inputs(batch_size, device=device)
             labels = m.get_rand_labels(batch_size, device=device)
@@ -276,6 +283,26 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         return {}
 
 
+def assert_equal_global_state(
+    rep0: dict[str, dict[str, dict[str, dict[str, object]]]],
+    rep1: dict[str, dict[str, dict[str, dict[str, object]]]],
+) -> None:
+    """
+    Asserts that the global state of the two replicas are equal
+    """
+    for step in rep0.keys():
+        torch.testing.assert_close(
+            rep1[step]["user"]["default"]["original_params"],
+            rep0[step]["user"]["default"]["original_params"],
+            check_device=False,
+        )
+        torch.testing.assert_close(
+            rep1[step]["user"]["default"]["outer_optim"],
+            rep0[step]["user"]["default"]["outer_optim"],
+            check_device=False,
+        )
+
+
 class LocalSGDIntegTest(TestCase):
     # TODO: race condition due to using NCCL in threads causes manager allreduce to sometimes not be correct
     # Because of that the test is disabled for cuda
@@ -447,6 +474,9 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         state_dicts = []
 
         for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
             try:
                 state_dicts.append(fut.result()[0])
             except Exception as e:
@@ -457,33 +487,23 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
 
         rep0, rep1 = state_dicts
 
-        for step in rep0.keys():
-            # Inner optimizer and local model parameters will be different e.g.
-            # with 2 replicas r1 and r2, we sync every 2 steps
-            #
-            # - Manager Step 1
-            #   - Step 1: r1 and r2 step
-            #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
-            # - Manager Step 2
-            #   - Step 1: r1 steps but r2 fails
-            #   - Step 2:
-            #     - r1 steps, sync fails because r2 is down
-            #     - r1 recovers r2 from the model state at this step
-            #       that is different from the model for r1 at the beginning
-            #       of step Manager Step 2
-            #
-            # Outer optimizer and global model should be the same
+        # Inner optimizer and local model parameters will be different e.g.
+        # with 2 replicas r1 and r2, we sync every 2 steps
+        #
+        # - Manager Step 1
+        #   - Step 1: r1 and r2 step
+        #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
+        # - Manager Step 2
+        #   - Step 1: r1 steps but r2 fails
+        #   - Step 2:
+        #     - r1 steps, sync fails because r2 is down
+        #     - r1 recovers r2 from the model state at this step
+        #       that is different from the model for r1 at the beginning
+        #       of step Manager Step 2
+        #
+        # Outer optimizer and global model should be the same
+        assert_equal_global_state(rep1, rep0)
 
-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["original_params"],
-                rep0[step]["user"]["default"]["original_params"],
-                check_device=False,
-            )
-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["outer_optim"],
-                rep0[step]["user"]["default"]["outer_optim"],
-                check_device=False,
-            )
         self.assertEqual(failure_injectors[1].count, 1)
 
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -552,6 +572,8 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
 
         rep0, rep1 = state_dicts
 
+        assert_equal_global_state(rep1, rep0)
+
        for step in rep1.keys():
            if step == 2:
                # Replica 0 should have reset its `local_step` after failure
@@ -562,14 +584,93 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
                     rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
                 )
 
-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["original_params"],
-                rep0[step]["user"]["default"]["original_params"],
-                check_device=False,
+        self.assertEqual(failure_injectors[1].count, 1)
+
+    CONFIG: list[tuple[bool, int, int]] = [
+        (use_cuda, n_fragments, fragment_sync_delay)
+        for use_cuda in [True, False]
+        for n_fragments in [1, 2]
+        for fragment_sync_delay in [0, 1]
+    ]
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(CONFIG)
+    def test_streaming_diloco_upscale(
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+    ) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 3
+        futures = []
+        executors = []
+
+        barrier = threading.Barrier(num_replicas)
+
+        barrier_injectors = [
+            # Make this replica join after other replicas have made 2 steps
+            BarrierInjector().barrier_at(0, barrier),
+            BarrierInjector().barrier_at(2, barrier),
+            BarrierInjector().barrier_at(2, barrier),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+        for replica_id, barrier_injector in zip(range(num_replicas), barrier_injectors):
+            executor = ThreadPoolExecutor(max_workers=1)
+            executors.append(executor)
+            runner = Runner(
+                replica_id=replica_id,
+                num_replicas=num_replicas,
+                lighthouse_address=lighthouse.address(),
+                failure_injector=FailureInjector(),
+                barrier_injector=barrier_injector,
+                train_loop=diloco_train_loop,
+                train_loop_args={
+                    "model_state_dict": m.state_dict(),
+                    "n_fragments": n_fragments,
+                    "diloco_args": {
+                        "fragment_sync_delay": fragment_sync_delay,
+                        "sync_every": 4,
+                    },
+                },
             )
-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["outer_optim"],
-                rep0[step]["user"]["default"]["outer_optim"],
-                check_device=False,
+            futures.append(executor.submit(runner.run_replica))
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
+            try:
+                state_dicts.append(fut.result()[0])
+            except Exception as e:
+                print(e)
+                raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1, rep2 = state_dicts
+
+        assert_equal_global_state(rep0, rep1)
+        assert_equal_global_state(rep0, rep2)
+
+        for step in rep0.keys():
+            self.assertEqual(
+                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
             )
-        self.assertEqual(failure_injectors[1].count, 1)
+            self.assertEqual(
+                rep1[step]["user"]["local_step"], rep2[step]["user"]["local_step"]
+            )
+
+        for barrier_injector in barrier_injectors:
+            self.assertEqual(barrier_injector.count, 1)
```
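
One detail worth calling out in the tests above: the loop that drains `as_completed(futures)` with a bare `continue` makes the test block until every replica thread has finished before any `fut.result()` call can re-raise, so a failure in one replica is not raised while the others are still mid-run. A minimal sketch of the same pattern (the `work` function is a hypothetical stand-in for `runner.run_replica`):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def work(i: int) -> int:
    return i * i  # stand-in for runner.run_replica

with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(work, i) for i in range(3)]

    # First wait until every future has settled (completed or failed) ...
    for fut in as_completed(futures):
        continue

    # ... then collect results in submission order; any exception re-raises
    # here, after all threads are done.
    results = [fut.result() for fut in futures]

print(results)  # [0, 1, 4]
```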

torchft/manager_integ_test.py (28 additions, 0 deletions)
```diff
@@ -83,6 +83,30 @@ def check(self, rank: int, step: int) -> None:
                 raise InjectedFailure(f"injected failure {rank=} {step=}")
 
 
+class BarrierInjector:
+    """
+    Used to wait for all ranks and replicas to reach a certain step before continuing.
+    Users need to make sure the size of the barrier is appropriately set.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._barriers: Dict[int, threading.Barrier] = dict()
+        self.count = 0
+
+    def barrier_at(self, step: int, barrier: threading.Barrier) -> "BarrierInjector":
+        with self._lock:
+            self._barriers[step] = barrier
+            return self
+
+    def check(self, step: int) -> None:
+        with self._lock:
+            if step in self._barriers:
+                self.count += 1
+                self._barriers[step].wait()
+                self._barriers.pop(step)
+
+
 # R for an arbitrary return type
 R = TypeVar("R", covariant=True)
 
@@ -106,6 +130,7 @@ class Runner:
     failure_injector: FailureInjector
     train_loop: TrainLoop[object]
 
+    barrier_injector: Optional[BarrierInjector] = None
     use_cuda: bool = False
     world_size: int = 1
     attempts: int = 3
@@ -223,6 +248,9 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         criterion = nn.CrossEntropyLoss()
 
         while True:
+            if runner.barrier_injector is not None:
+                runner.barrier_injector.check(manager.current_step())
+
             inputs = torch.rand(2, 3)
             labels = torch.randint(4, (2,))
 
```
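To make the `check` semantics concrete: a gate registered with `barrier_at` fires at most once, because `check` pops the barrier after waiting on it, and `count` records how many gates actually fired; this is what the `count == 1` assertions at the end of the upscale test rely on. A small hedged example, assuming the `BarrierInjector` above is in scope (the step numbers are illustrative):

```python
import threading

barrier = threading.Barrier(2)
inj = BarrierInjector().barrier_at(3, barrier)

# A helper thread stands in for the second replica reaching its own gate.
t = threading.Thread(target=barrier.wait)
t.start()

inj.check(2)  # no gate registered for step 2: returns immediately
inj.check(3)  # blocks until the other party waits, then pops the step-3 gate
inj.check(3)  # gate already consumed: returns immediately
t.join()

assert inj.count == 1  # only gates that actually fired are counted
```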