@@ -52,6 +52,16 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
     return {name: value.clone().detach() for name, value in state_dict.items()}


+class TinyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+        self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+
 class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
@@ -216,24 +226,10 @@ def test_diloco_allreduce_call_efficiency(
         self.assertEqual(int(allreduce_calls), int(param_count))

     def test_bucketization_correctness(self) -> None:
-        class TinyModel(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
-                self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
-
-            def forward(self, x):
-                return x @ self.w1.unsqueeze(0).T + self.w2.sum()
-
         model = TinyModel()
         inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
         outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)

-        # Manually assign fake gradients
-        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
-        for p, g in zip(model.parameters(), grads):
-            p.grad = g.clone()
-
         manager = create_autospec(Manager)
         manager._use_async_quorum = False
         manager.should_commit.return_value = True
@@ -254,10 +250,71 @@ def fake_allreduce(
         )
         diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024

+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for g, (name, param) in zip(grads, model.named_parameters()):
+            diloco._fragments[0]._grads[name] = g.clone()
+
         # Run only bucketized logic
         diloco._fragments[0]._average_grads()

+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
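+        # _set_grads() should copy the averaged values onto each parameter's .grad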
+        diloco._fragments[0]._set_grads()
+
         # Expect grads to have been doubled
         expected_grads = [g * 2 for g in grads]
         for param, expected in zip(model.parameters(), expected_grads):
             torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)
+
+    def test_gradient_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(
+            tensor: Tensor, should_quantize: bool
+        ) -> torch.futures.Future[Tensor]:
+            tensor.mul_(2)
+            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+            fut.set_result(tensor)
+            return fut
+
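+        # route the manager's allreduce through the fake so the doubling is observable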
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(manager, [model], inner_opt, outer_opt, sync_every=2)
+
+        # save original parameters
+        diloco._fragments[0].save_parameters()
+
+        # change the model's parameters
+        for p in model.parameters():
+            p.data.add_(2)
+
+        # calculate and save the gradients
+        diloco._fragments[0]._save_grads()
+
+        # average the saved gradients (the fake allreduce doubles them)
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set yet
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # we added 2 to the parameters, then the fake allreduce multiplied the
+        # gradients by 2, so each gradient element should be 4
+        expected_grad = 4
+        for param in model.parameters():
+            assert param.grad is not None
+            t = torch.empty_like(param.grad)
+            t.fill_(expected_grad)
+            torch.testing.assert_close(param.grad, t)