 import os
 import re
 import sys
+import threading
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
+from dataclasses import field
 from datetime import timedelta
 from typing import Any, Dict
 from unittest import TestCase, skipIf

 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.distributed.tensor import DTensor, Replicate

 from torchft._torchft import LighthouseServer
 from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
-from torchft.manager_integ_test import FailureInjector, MyModel, Runner
+from torchft.manager_integ_test import BarrierInjector, FailureInjector, MyModel, Runner
 from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo

 logger: logging.Logger = logging.getLogger(__name__)
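
> Note: the new `threading` import exists to build a `threading.Barrier` for the upscale test added below. As a refresher on the primitive (a minimal standalone sketch, not part of this PR): `Barrier(n)` blocks each caller of `wait()` until `n` threads have arrived, then releases them all at once.

```python
import threading

barrier = threading.Barrier(3)  # releases once 3 threads are waiting

def worker(name: str) -> None:
    print(f"{name} waiting")
    barrier.wait()  # blocks until all three workers arrive
    print(f"{name} released")

threads = [threading.Thread(target=worker, args=(f"w{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```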
@@ -254,6 +257,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
         all_state_dicts[manager_curr_step] = copy.deepcopy(
             manager._manager_state_dict()
         )
+
+        if runner.barrier_injector is not None:
+            runner.barrier_injector.check(manager_curr_step)
+
         batch_size = 1
         inputs = m.get_rand_inputs(batch_size, device=device)
         labels = m.get_rand_labels(batch_size, device=device)
@@ -276,6 +283,26 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
     return {}


+def assert_equal_global_state(
+    rep0: dict[str, dict[str, dict[str, dict[str, object]]]],
+    rep1: dict[str, dict[str, dict[str, dict[str, object]]]],
+) -> None:
+    """
+    Asserts that the global state of the two replicas is equal.
+    """
+    for step in rep0.keys():
+        torch.testing.assert_close(
+            rep1[step]["user"]["default"]["original_params"],
+            rep0[step]["user"]["default"]["original_params"],
+            check_device=False,
+        )
+        torch.testing.assert_close(
+            rep1[step]["user"]["default"]["outer_optim"],
+            rep0[step]["user"]["default"]["outer_optim"],
+            check_device=False,
+        )
+
+
 class LocalSGDIntegTest(TestCase):
     # TODO: race condition due to using NCCL in threads causes manager allreduce to sometimes not be correct
     # Because of that the test is disabled for cuda
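
> The quadruple-nested `dict` annotation mirrors the per-step snapshots captured from `manager._manager_state_dict()` above: step, then the `"user"` namespace, then the fragment name (`"default"`), then the state keys compared here. A hypothetical example of the shape (the concrete keys and values come from the manager; treat this only as an illustration):

```python
import torch

# Hypothetical snapshot for a single manager step; the real values are the
# tensors and optimizer state captured by the train loop.
rep0 = {
    "3": {  # manager step (key type is illustrative)
        "user": {
            "default": {  # fragment name
                "original_params": {"weight": torch.zeros(2, 3)},
                "outer_optim": {"state": {}, "param_groups": []},
            }
        }
    }
}
```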
@@ -447,6 +474,9 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         state_dicts = []

         for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
             try:
                 state_dicts.append(fut.result()[0])
             except Exception as e:
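
> The empty `as_completed` loop is effectively a join: it blocks until every replica thread finishes (success or failure) before any `.result()` call can raise, and the second loop then collects results in submission order so `state_dicts[i]` lines up with replica `i`. `concurrent.futures.wait` expresses the same intent more directly (a sketch over the same `futures` list):

```python
from concurrent.futures import wait

# Block until every replica thread has finished, successful or not,
# so no thread is still running when an exception propagates.
wait(futures)

# Collect in submission order; .result() re-raises a replica's exception.
state_dicts = [fut.result()[0] for fut in futures]
```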
@@ -457,33 +487,23 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:

         rep0, rep1 = state_dicts

-        for step in rep0.keys():
-            # Inner optimizer and local model parameters will be different, e.g.
-            # with 2 replicas r1 and r2, we sync every 2 steps:
-            #
-            # - Manager Step 1
-            #   - Step 1: r1 and r2 step
-            #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
-            # - Manager Step 2
-            #   - Step 1: r1 steps but r2 fails
-            #   - Step 2:
-            #     - r1 steps, sync fails because r2 is down
-            #     - r1 recovers r2 from the model state at this step,
-            #       which is different from r1's model at the beginning
-            #       of Manager Step 2
-            #
-            # The outer optimizer and global model should be the same.
+        # Inner optimizer and local model parameters will be different, e.g.
+        # with 2 replicas r1 and r2, we sync every 2 steps:
+        #
+        # - Manager Step 1
+        #   - Step 1: r1 and r2 step
+        #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
+        # - Manager Step 2
+        #   - Step 1: r1 steps but r2 fails
+        #   - Step 2:
+        #     - r1 steps, sync fails because r2 is down
+        #     - r1 recovers r2 from the model state at this step,
+        #       which is different from r1's model at the beginning
+        #       of Manager Step 2
+        #
+        # The outer optimizer and global model should be the same.
+        assert_equal_global_state(rep1, rep0)

-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["original_params"],
-                rep0[step]["user"]["default"]["original_params"],
-                check_device=False,
-            )
-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["outer_optim"],
-                rep0[step]["user"]["default"]["outer_optim"],
-                check_device=False,
-            )
         self.assertEqual(failure_injectors[1].count, 1)

         # pyre-fixme[56]: Pyre was not able to infer the type of argument
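
> For context on why only the outer state is compared: in DiLoCo each replica runs `sync_every` inner optimizer steps on its own data (so inner state diverges), then the replicas average their pseudo-gradients and take one outer step on the shared global parameters. A schematic sketch of that sync, using plain tensors rather than torchft's actual implementation:

```python
import torch

g = torch.zeros(4)             # global (outer) parameters
x1, x2 = g.clone(), g.clone()  # per-replica local parameters

for _ in range(2):              # sync_every inner steps per replica
    x1 -= 0.1 * torch.randn(4)  # inner steps diverge because each
    x2 -= 0.1 * torch.randn(4)  # replica sees different gradients

pseudo_grad = ((g - x1) + (g - x2)) / 2  # average pseudo-gradient
g -= 0.7 * pseudo_grad                   # outer optimizer step
x1 = g.clone()                           # replicas re-sync to the
x2 = g.clone()                           # updated global parameters
```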
@@ -552,6 +572,8 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:

         rep0, rep1 = state_dicts

+        assert_equal_global_state(rep1, rep0)
+
         for step in rep1.keys():
             if step == 2:
                 # Replica 0 should have reset its `local_step` after failure
@@ -562,14 +584,93 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
                 rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
             )

-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["original_params"],
-                rep0[step]["user"]["default"]["original_params"],
-                check_device=False,
+        self.assertEqual(failure_injectors[1].count, 1)
+
+    CONFIG: list[tuple[bool, int, int]] = [
+        (use_cuda, n_fragments, fragment_sync_delay)
+        for use_cuda in [True, False]
+        for n_fragments in [1, 2]
+        for fragment_sync_delay in [0, 1]
+    ]
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(CONFIG)
+    def test_streaming_diloco_upscale(
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+    ) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 3
+        futures = []
+        executors = []
+
+        barrier = threading.Barrier(num_replicas)
+
+        barrier_injectors = [
+            # Make replica 0 join only after the other replicas have taken 2 steps
+            BarrierInjector().barrier_at(0, barrier),
+            BarrierInjector().barrier_at(2, barrier),
+            BarrierInjector().barrier_at(2, barrier),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+        for replica_id, barrier_injector in zip(range(num_replicas), barrier_injectors):
+            executor = ThreadPoolExecutor(max_workers=1)
+            executors.append(executor)
+            runner = Runner(
+                replica_id=replica_id,
+                num_replicas=num_replicas,
+                lighthouse_address=lighthouse.address(),
+                failure_injector=FailureInjector(),
+                barrier_injector=barrier_injector,
+                train_loop=diloco_train_loop,
+                train_loop_args={
+                    "model_state_dict": m.state_dict(),
+                    "n_fragments": n_fragments,
+                    "diloco_args": {
+                        "fragment_sync_delay": fragment_sync_delay,
+                        "sync_every": 4,
+                    },
+                },
             )
-            torch.testing.assert_close(
-                rep1[step]["user"]["default"]["outer_optim"],
-                rep0[step]["user"]["default"]["outer_optim"],
-                check_device=False,
+            futures.append(executor.submit(runner.run_replica))
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
+            try:
+                state_dicts.append(fut.result()[0])
+            except Exception as e:
+                print(e)
+                raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1, rep2 = state_dicts
+
+        assert_equal_global_state(rep0, rep1)
+        assert_equal_global_state(rep0, rep2)
+
+        for step in rep0.keys():
+            self.assertEqual(
+                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
             )
-        self.assertEqual(failure_injectors[1].count, 1)
+            self.assertEqual(
+                rep1[step]["user"]["local_step"], rep2[step]["user"]["local_step"]
+            )
+
+        for barrier_injector in barrier_injectors:
+            self.assertEqual(barrier_injector.count, 1)
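
> `BarrierInjector` itself lives in `torchft/manager_integ_test.py` and is not shown in this diff. From its call sites here — the fluent `barrier_at(step, barrier)`, the `check(step)` hook in the train loop, and the final `count` assertion — a minimal implementation might look like the following (a hypothetical sketch inferred from usage, not the actual torchft code):

```python
import threading
from typing import Optional


class BarrierInjector:
    """Blocks a replica at a given manager step until all parties arrive.

    Hypothetical sketch based on the call sites in this diff; see
    torchft/manager_integ_test.py for the real implementation.
    """

    def __init__(self) -> None:
        self._step: Optional[int] = None
        self._barrier: Optional[threading.Barrier] = None
        self.count = 0  # how many times this injector fired

    def barrier_at(
        self, step: int, barrier: threading.Barrier
    ) -> "BarrierInjector":
        self._step = step
        self._barrier = barrier
        return self  # fluent style: BarrierInjector().barrier_at(...)

    def check(self, step: int) -> None:
        # Called from the train loop each manager step; wait exactly once
        # when the configured step is reached.
        if self._barrier is not None and step == self._step:
            self._barrier.wait()
            self._barrier = None  # fire only once
            self.count += 1
```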