Commit e5dfb43

add streaming diloco test - recovery (#222)

Summary:
- update diloco tests to work with model fragments
- propagate arguments for the training loop and optimizer wrapper through the test runner
- add a test that validates streaming diloco works when a node crashes and rejoins
- store the whole manager state dict for validation instead of only the user's default state dict
- fix a bug that made the local step go out of sync

Test Plan:
```
pytest -vs ./torchft/local_sgd_integ_test.py
```

1 parent cc6a70c commit e5dfb43

4 files changed, +182 -30 lines changed

torchft/local_sgd.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -680,6 +680,9 @@ def _step_post_hook(
             # Get the correct step when. This will continue after other committed.
             self._quorum_loop()
             self._should_recover = False
+            # This is to be consistent with the nodes that are not recovering. They
+            # proceed with the code below on the step after quorum completes.
+            return

         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
```
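Why the early `return` matters: a replica that just rejoined completes `_quorum_loop()` inside its post-step hook, while healthy replicas only reach the fragment-sync logic on the step after the quorum completes. The sketch below is a minimal, self-contained illustration of that control flow; the class and helper names are hypothetical stand-ins, not torchft's actual internals.

```python
# Illustrative sketch only -- names and structure are hypothetical stand-ins
# for torchft's _step_post_hook, which is far more involved.


class _HookState:
    def __init__(self) -> None:
        self.local_step = 0
        self.should_recover = True  # set when this replica just rejoined


def _quorum_loop(state: _HookState) -> None:
    # Stand-in: block until a quorum forms and the manager step is agreed on.
    pass


def _sync_fragments(state: _HookState) -> None:
    # Stand-in: every replica must enter this on the same step, in the same
    # fragment order, or the collectives deadlock.
    pass


def _step_post_hook(state: _HookState) -> None:
    state.local_step += 1
    if state.should_recover:
        _quorum_loop(state)
        state.should_recover = False
        # Healthy replicas only run the fragment sync on the step *after*
        # the quorum completes, so the recovering replica must skip it too;
        # falling through here is what let local_step drift out of sync.
        return
    _sync_fragments(state)
```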

torchft/local_sgd_integ_test.py
Lines changed: 161 additions & 29 deletions

```diff
@@ -2,13 +2,13 @@
 import logging
 import os
 import re
+import sys
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
 from datetime import timedelta
-from sys import platform
 from typing import Any, Dict
-from unittest import TestCase
+from unittest import TestCase, skipIf

 import torch
 from parameterized import parameterized
@@ -25,11 +25,40 @@
 logger: logging.Logger = logging.getLogger(__name__)


+class MultiMyModel(torch.nn.Module):
+    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+        super().__init__()
+        self.in_dim = in_dim
+
+        self.layers = torch.nn.ModuleList()
+        for i in range(n_layers):
+            self.layers.append(MyModel(in_dim, out_dim))
+            in_dim, out_dim = out_dim, in_dim
+
+        self.out_dim = in_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(self.out_dim, (batch_size,), device=device)
+
+
 def local_sgd_train_loop(
     rank: int,
     store_port: int,
     device: torch.device,
     runner: Runner,
+    train_loop_args: dict[str, Any] = {},
 ) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:
@@ -99,11 +128,16 @@ def diloco_train_loop(
     store_port: int,
     device: torch.device,
     runner: Runner,
+    train_loop_args: dict[str, Any] = {},
 ) -> Dict[str, Dict[str, object]]:
+
+    model_state_dict = train_loop_args.get("model_state_dict", {})
+    n_fragments = train_loop_args.get("n_fragments", 1)
+    diloco_args = train_loop_args.get("diloco_args", {})
+
     with ExitStack() as stack:
         # Declare the model and optimizers
-        m: nn.Module = MyModel(2, 3)
-        model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
+        m = MultiMyModel(2, 3, n_fragments)
         m.load_state_dict(model_state_dict)
         m.to(device)
@@ -119,18 +153,26 @@ def diloco_train_loop(
         def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
             m.load_state_dict(state_dict["model"])
             m.to(device)
-            diloco._fragments[0].original_parameters = state_dict["original_params"]
-            for name in diloco._fragments[0].original_parameters.keys():
-                diloco._fragments[0].original_parameters[name] = (
-                    diloco._fragments[0].original_parameters[name].to(device)
-                )
+
+            for i, fragment in enumerate(diloco._fragments):
+                fragment.original_parameters = state_dict["original_params"][f"{i}"]
+
+            for fragment in diloco._fragments:
+                for name in fragment.original_parameters.keys():
+                    fragment.original_parameters[name] = fragment.original_parameters[
+                        name
+                    ].to(device)
+
             inner_optimizer.load_state_dict(state_dict["inner_optim"])
             outer_optimizer.load_state_dict(state_dict["outer_optim"])

         def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             return {
                 "model": m.state_dict(),
-                "original_params": diloco._fragments[0].original_parameters,
+                "original_params": {
+                    f"{i}": fragment.original_parameters
+                    for i, fragment in enumerate(diloco._fragments)
+                },
                 "inner_optim": inner_optimizer.state_dict(),
                 "outer_optim": outer_optimizer.state_dict(),
             }
@@ -194,18 +236,24 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]

         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
+
+        if "sync_every" not in diloco_args:
+            diloco_args["sync_every"] = 2
+
         with DiLoCo(
             manager,
-            [m],
+            [layer for layer in m.layers],
             inner_optimizer,
             outer_optimizer,
             backup_device=device,
-            sync_every=2,
+            **diloco_args,
         ) as diloco:
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
+                    all_state_dicts[manager_curr_step] = copy.deepcopy(
+                        manager._manager_state_dict()
+                    )
                 batch_size = 1
                 inputs = m.get_rand_inputs(batch_size, device=device)
                 labels = m.get_rand_labels(batch_size, device=device)
@@ -308,7 +356,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:

         torch.manual_seed(42)
         # Initialize the model so we can pass in the state_dict
-        m: nn.Module = MyModel(2, 3)
+        m: nn.Module = MultiMyModel(2, 3, 1)

         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
@@ -341,16 +389,18 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
         for step, state_dict in rep1.items():
             # inner optimizer will be different, outer optimizer and model should be the same
             torch.testing.assert_close(
-                state_dict["model"],
-                rep0[step]["model"],
+                state_dict["user"]["default"]["model"],
+                rep0[step]["user"]["default"]["model"],
                 check_device=False,
             )
             torch.testing.assert_close(
-                state_dict["outer_optim"],
-                rep0[step]["outer_optim"],
+                state_dict["user"]["default"]["outer_optim"],
+                rep0[step]["user"]["default"]["outer_optim"],
                 check_device=False,
             )

+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
     @parameterized.expand(
         [
             # (True,),
@@ -362,12 +412,6 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")

-        if platform == "darwin":
-            # TODO: This is likely because of Gloo not releasing GIL.
-            # Fix in: https://github.com/pytorch/pytorch/pull/154976
-            # Once this makes it to a stable package, we can re-enable this test.
-            self.skipTest("Known issue in Gloo")
-
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
@@ -382,7 +426,7 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:

         torch.manual_seed(42)
         # Initialize the model so we can pass in the state_dict
-        m: nn.Module = MyModel(2, 3)
+        m: nn.Module = MultiMyModel(2, 3, 1)

         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id, failure_injector in zip(
@@ -431,13 +475,101 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
             # Outer optimizer and global model should be the same

             torch.testing.assert_close(
-                rep1[step]["original_params"],
-                rep0[step]["original_params"],
+                rep1[step]["user"]["default"]["original_params"],
+                rep0[step]["user"]["default"]["original_params"],
+                check_device=False,
+            )
+            torch.testing.assert_close(
+                rep1[step]["user"]["default"]["outer_optim"],
+                rep0[step]["user"]["default"]["outer_optim"],
+                check_device=False,
+            )
+        self.assertEqual(failure_injectors[1].count, 1)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(
+        [
+            # (True,),
+            (False,),
+        ]
+    )
+    def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(0, 2),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, 2)
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=diloco_train_loop,
+                    train_loop_args={
+                        "model_state_dict": m.state_dict(),
+                        "n_fragments": 2,
+                        "diloco_args": {
+                            "fragment_sync_delay": 1,
+                            "sync_every": 4,
+                        },
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                continue
+
+            for fut in futures:
+                try:
+                    state_dicts.append(fut.result()[0])
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1 = state_dicts
+
+        for step in rep1.keys():
+            if step == 2:
+                # Replica 0 should have reset its `local_step` after failure
+                self.assertEqual(rep1[step]["user"]["local_step"], 0)
+                self.assertEqual(rep0[step]["user"]["local_step"], 5)
+            else:
+                self.assertEqual(
+                    rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
+                )
+
+            torch.testing.assert_close(
+                rep1[step]["user"]["default"]["original_params"],
+                rep0[step]["user"]["default"]["original_params"],
                 check_device=False,
             )
             torch.testing.assert_close(
-                rep1[step]["outer_optim"],
-                rep0[step]["outer_optim"],
+                rep1[step]["user"]["default"]["outer_optim"],
+                rep0[step]["user"]["default"]["outer_optim"],
                 check_device=False,
             )
         self.assertEqual(failure_injectors[1].count, 1)
```
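The assertions above now index into the whole manager state dict rather than only the user-registered one. Below is a rough sketch of the per-step snapshot layout those assertions imply; it is inferred from this diff alone (`_manager_state_dict()` is a private torchft API, so keys beyond the ones asserted on are guesses).

```python
# Per-step snapshot shape implied by the assertions above; inferred from the
# diff, not an official schema.
snapshot = {
    "user": {
        "local_step": 5,  # DiLoCo's per-replica step counter
        "default": {  # the state_dict() the train loop registers
            "model": {},  # m.state_dict()
            "original_params": {  # one entry per fragment, keyed "0", "1", ...
                "0": {},
                "1": {},
            },
            "inner_optim": {},  # inner_optimizer.state_dict()
            "outer_optim": {},  # outer_optimizer.state_dict()
        },
    },
}

# test_streaming_diloco_recovery then checks, at the crash step, that the
# restarted replica's counter reset (rep1[step]["user"]["local_step"] == 0)
# while the surviving replica's kept counting (rep0[...] == 5), and that
# original_params and outer_optim still match across both replicas.
```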

torchft/local_sgd_test.py
Lines changed: 9 additions & 0 deletions

```diff
@@ -157,6 +157,11 @@ def test_diloco_healthy(self) -> None:
             loss.backward()
             inner_optimizer.step()

+            self.assertEqual(diloco._local_step, 0)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
             self.assertEqual(diloco._local_step, 1)
             self.assertEqual(manager.start_quorum.call_count, 1)
             loss = model(inp).mean()
@@ -217,6 +222,10 @@ def test_diloco_allreduce_call_efficiency(
             loss.backward()
             inner_optimizer.step()

+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
             allreduce_calls = manager.allreduce.call_count
             param_count = len([p for p in model.parameters() if p.requires_grad])
```
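Both unit tests now take one extra forward/backward/step before their first assertions. Read together with the `local_sgd.py` fix above, the implied accounting is that the very first step only establishes the quorum, so `_local_step` is still 0 afterward and reaches 1 on the following step. This timeline is a reading of the updated assertions, not a documented torchft contract:

```python
# Assumed step/counter timeline (inferred from the updated tests):
#
#   inner step 1 -> post-hook completes the quorum and returns -> _local_step == 0
#   inner step 2 -> post-hook takes the normal path            -> _local_step == 1
#   inner step 3 -> normal path                                -> _local_step == 2
#
# test_diloco_allreduce_call_efficiency adds the extra step for the same
# reason: manager.allreduce.call_count is only meaningful once a sync has
# actually run.
```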

torchft/manager_integ_test.py
Lines changed: 9 additions & 1 deletion

```diff
@@ -89,7 +89,12 @@ def check(self, rank: int, step: int) -> None:

 class TrainLoop(Protocol[R]):
     def __call__(
-        self, rank: int, store_port: int, device: torch.device, runner: "Runner"
+        self,
+        rank: int,
+        store_port: int,
+        device: torch.device,
+        runner: "Runner",
+        train_loop_args: dict[str, Any] = field(default_factory=dict),
     ) -> R: ...


@@ -137,6 +142,7 @@ def _replica_main(self) -> List[object]:
                     store_port=store.port,
                     device=device,
                     runner=self,
+                    train_loop_args=self.train_loop_args,
                 )
             )

@@ -170,6 +176,7 @@ def ddp_train_loop(
     store_port: int,
     device: torch.device,
     runner: Runner,
+    train_loop_args: dict[str, Any] = {},
 ) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:

@@ -525,6 +532,7 @@ def all_reduce_callback(
     store_port: int,
     device: torch.device,
     runner: Runner,
+    train_loop_args: dict[str, Any] = {},
 ) -> Optional[torch.Tensor]:
     with ExitStack() as stack:
         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
```
