 import logging
 import os
 import re
+import sys
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
 from datetime import timedelta
-from sys import platform
 from typing import Any, Dict
-from unittest import TestCase
+from unittest import TestCase, skipIf

 import torch
 from parameterized import parameterized

 logger: logging.Logger = logging.getLogger(__name__)


+class MultiMyModel(torch.nn.Module):
+    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+        super().__init__()
+        self.in_dim = in_dim
+
+        self.layers = torch.nn.ModuleList()
+        for i in range(n_layers):
+            self.layers.append(MyModel(in_dim, out_dim))
+            in_dim, out_dim = out_dim, in_dim
+
+        self.out_dim = in_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(self.out_dim, (batch_size,), device=device)
+
+
 def local_sgd_train_loop(
     rank: int,
     store_port: int,
     device: torch.device,
     runner: Runner,
+    train_loop_args: dict[str, Any] = {},
 ) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:
@@ -99,11 +128,16 @@ def diloco_train_loop(
     store_port: int,
     device: torch.device,
     runner: Runner,
+    train_loop_args: dict[str, Any] = {},
 ) -> Dict[str, Dict[str, object]]:
+
+    model_state_dict = train_loop_args.get("model_state_dict", {})
+    n_fragments = train_loop_args.get("n_fragments", 1)
+    diloco_args = train_loop_args.get("diloco_args", {})
+
     with ExitStack() as stack:
         # Declare the model and optimizers
-        m: nn.Module = MyModel(2, 3)
-        model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
+        m = MultiMyModel(2, 3, n_fragments)
         m.load_state_dict(model_state_dict)
         m.to(device)
@@ -119,18 +153,26 @@ def diloco_train_loop(
         def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
             m.load_state_dict(state_dict["model"])
             m.to(device)
-            diloco._fragments[0].original_parameters = state_dict["original_params"]
-            for name in diloco._fragments[0].original_parameters.keys():
-                diloco._fragments[0].original_parameters[name] = (
-                    diloco._fragments[0].original_parameters[name].to(device)
-                )
+
+            for i, fragment in enumerate(diloco._fragments):
+                fragment.original_parameters = state_dict["original_params"][f"{i}"]
+
+            for fragment in diloco._fragments:
+                for name in fragment.original_parameters.keys():
+                    fragment.original_parameters[name] = fragment.original_parameters[
+                        name
+                    ].to(device)
+
             inner_optimizer.load_state_dict(state_dict["inner_optim"])
             outer_optimizer.load_state_dict(state_dict["outer_optim"])

         def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             return {
                 "model": m.state_dict(),
-                "original_params": diloco._fragments[0].original_parameters,
+                "original_params": {
+                    f"{i}": fragment.original_parameters
+                    for i, fragment in enumerate(diloco._fragments)
+                },
                 "inner_optim": inner_optimizer.state_dict(),
                 "outer_optim": outer_optimizer.state_dict(),
             }
@@ -194,18 +236,24 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
+
+        if "sync_every" not in diloco_args:
+            diloco_args["sync_every"] = 2
+
         with DiLoCo(
             manager,
-            [m],
+            [layer for layer in m.layers],
             inner_optimizer,
             outer_optimizer,
             backup_device=device,
-            sync_every=2,
+            **diloco_args,
         ) as diloco:
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
+                    all_state_dicts[manager_curr_step] = copy.deepcopy(
+                        manager._manager_state_dict()
+                    )
                 batch_size = 1
                 inputs = m.get_rand_inputs(batch_size, device=device)
                 labels = m.get_rand_labels(batch_size, device=device)
@@ -308,7 +356,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
         torch.manual_seed(42)
         # Initialize the model so we can pass in the state_dict
-        m: nn.Module = MyModel(2, 3)
+        m: nn.Module = MultiMyModel(2, 3, 1)

         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
@@ -341,16 +389,18 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
         for step, state_dict in rep1.items():
             # inner optimizer will be different, outer optimizer and model should be the same
             torch.testing.assert_close(
-                state_dict["model"],
-                rep0[step]["model"],
+                state_dict["user"]["default"]["model"],
+                rep0[step]["user"]["default"]["model"],
                 check_device=False,
             )
             torch.testing.assert_close(
-                state_dict["outer_optim"],
-                rep0[step]["outer_optim"],
+                state_dict["user"]["default"]["outer_optim"],
+                rep0[step]["user"]["default"]["outer_optim"],
                 check_device=False,
             )

+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
     @parameterized.expand(
         [
             # (True,),
@@ -362,12 +412,6 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")

-        if platform == "darwin":
-            # TODO: This is likely because of Gloo not releasing GIL.
-            # Fix in: https://github.com/pytorch/pytorch/pull/154976
-            # Once this makes it to a stable package, we can re-enable this test.
-            self.skipTest("Known issue in Gloo")
-
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
@@ -382,7 +426,7 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         torch.manual_seed(42)
         # Initialize the model so we can pass in the state_dict
-        m: nn.Module = MyModel(2, 3)
+        m: nn.Module = MultiMyModel(2, 3, 1)

         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id, failure_injector in zip(
@@ -431,13 +475,101 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
             # Outer optimizer and global model should be the same

             torch.testing.assert_close(
-                rep1[step]["original_params"],
-                rep0[step]["original_params"],
+                rep1[step]["user"]["default"]["original_params"],
+                rep0[step]["user"]["default"]["original_params"],
+                check_device=False,
+            )
+            torch.testing.assert_close(
+                rep1[step]["user"]["default"]["outer_optim"],
+                rep0[step]["user"]["default"]["outer_optim"],
+                check_device=False,
+            )
+        self.assertEqual(failure_injectors[1].count, 1)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(
+        [
+            # (True,),
+            (False,),
+        ]
+    )
+    def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(0, 2),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, 2)
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=diloco_train_loop,
+                    train_loop_args={
+                        "model_state_dict": m.state_dict(),
+                        "n_fragments": 2,
+                        "diloco_args": {
+                            "fragment_sync_delay": 1,
+                            "sync_every": 4,
+                        },
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                continue
+
+            for fut in futures:
+                try:
+                    state_dicts.append(fut.result()[0])
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1 = state_dicts
+
+        for step in rep1.keys():
+            if step == 2:
+                # Replica 1 (the failed replica) should have reset its `local_step`
+                self.assertEqual(rep1[step]["user"]["local_step"], 0)
+                self.assertEqual(rep0[step]["user"]["local_step"], 5)
+            else:
+                self.assertEqual(
+                    rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
+                )
+
+            torch.testing.assert_close(
+                rep1[step]["user"]["default"]["original_params"],
+                rep0[step]["user"]["default"]["original_params"],
                 check_device=False,
             )
             torch.testing.assert_close(
-                rep1[step]["outer_optim"],
-                rep0[step]["outer_optim"],
+                rep1[step]["user"]["default"]["outer_optim"],
+                rep0[step]["user"]["default"]["outer_optim"],
                 check_device=False,
             )
         self.assertEqual(failure_injectors[1].count, 1)