
Commit 87fbc95

fix gradient allreduce (#215)
Summary:
- fix setting `_local_tensor` of a DTensor directly
- fix bucketized allreduce so it no longer uses `parameter.grad`
- simplify some code

Test Plan:
- added a test to validate that the gradients are saved and set correctly
- the previous test in `local_sgd_test` fails because allreduce is no longer performed on `param.grad`
- updated the test to first set the grads, then load the grads to make sure they reflect the allreduce result
1 parent 97b8d5c commit 87fbc95
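The core of the fix is how a DTensor parameter's gradient is installed after the allreduce: instead of mutating `p.grad._local_tensor` in place, the reduced local shard is wrapped back into a DTensor. Below is a minimal sketch of that pattern, assuming a PyTorch build that exposes `torch.distributed.tensor.DTensor`; the helper name `install_pseudograd` is illustrative and not part of torchft.

import torch
from torch.distributed.tensor import DTensor  # assumes a recent PyTorch


def install_pseudograd(p: torch.nn.Parameter, reduced: torch.Tensor) -> None:
    """Hypothetical helper mirroring the pattern used in this commit's _set_grads."""
    with torch.no_grad():
        if isinstance(p, DTensor):
            # Wrap the reduced local shard in a fresh DTensor rather than
            # assigning to p.grad._local_tensor directly.
            p.grad = DTensor.from_local(
                reduced,
                p.device_mesh,
                p.placements,
                shape=p.shape,
                stride=p.stride(),
            )
        else:
            # Plain tensors can take the reduced tensor as-is (no copy).
            p.grad = reduced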

2 files changed: +111 −42 lines changed

torchft/local_sgd.py

Lines changed: 40 additions & 28 deletions
@@ -257,17 +257,41 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
+    def _save_grads(self) -> None:
+        """
+        Saves pseudo-gradients of the parameters
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                if isinstance(p, DTensor):
+                    local_param = p.to_local()
+                else:
+                    local_param = p
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                self._grads[name] = pseudogradient
+
     def _set_grads(self) -> None:
         """
         Sets the gradients of the model fragment from the allreduce result
         """
-        for name, p in self._model_fragment.named_parameters():
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = self._grads[name]
-            else:
-                p.grad = self._grads[name]
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]
 
-            del self._grads[name]
+                # No longer needed
+                del self._grads[name]
 
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
@@ -304,14 +328,9 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                self._grads[name] = pseudogradient
-            else:
-                self._grads[name] = pseudogradient
+        self._save_grads()
+
+        assert len(self._allreduce_futures) == 0
 
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -371,18 +390,12 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
-            else:
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
             self._allreduce_futures.append(work)
 
-    def bucketize_and_allreduce(
+    def _bucketize_and_allreduce(
         self,
         tensors: List[torch.Tensor],
         bucket_size_bytes: int,
@@ -439,10 +452,9 @@ def _allreduce_bucketized(self) -> None:
         """
        Averages gradients using bucketized allreduce with a fixed buffer.
         """
-        grads = [
-            p.grad for p in self._model_fragment.parameters() if p.grad is not None
-        ]
-        self.bucketize_and_allreduce(
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
            grads,
            bucket_size_bytes=self.bucket_cap_mb,
        )
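Taken together, these changes keep the pseudo-gradients in the fragment's `_grads` dict for the whole sync instead of staging them in `param.grad`. A rough sketch of the resulting call order, assuming a fragment object exposing the methods shown in this diff (the free function below is illustrative, and the exact sequencing inside torchft's DiLoCo loop may differ):

def sync_fragment(fragment) -> None:
    # 1. _save_grads(): pseudogradient = local_param - original_param,
    #    stored in fragment._grads; param.grad is left untouched.
    fragment._save_grads()

    # 2. The allreduce (per-param or bucketized) reads and reduces
    #    fragment._grads, so param.grad stays None while it is in flight.
    fragment._average_grads()
    fragment.wait()

    # 3. _set_grads(): the reduced tensors are installed as param.grad
    #    (wrapped via DTensor.from_local where needed) and the staging
    #    entries in fragment._grads are deleted.
    fragment._set_grads()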

torchft/local_sgd_test.py

Lines changed: 71 additions & 14 deletions
@@ -52,6 +52,16 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
     return {name: value.clone().detach() for name, value in state_dict.items()}
 
 
+class TinyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+        self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+
 class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
@@ -216,24 +226,10 @@ def test_diloco_allreduce_call_efficiency(
         self.assertEqual(int(allreduce_calls), int(param_count))
 
     def test_bucketization_correctness(self) -> None:
-        class TinyModel(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
-                self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
-
-            def forward(self, x):
-                return x @ self.w1.unsqueeze(0).T + self.w2.sum()
-
         model = TinyModel()
         inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
         outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
 
-        # Manually assign fake gradients
-        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
-        for p, g in zip(model.parameters(), grads):
-            p.grad = g.clone()
-
         manager = create_autospec(Manager)
         manager._use_async_quorum = False
         manager.should_commit.return_value = True
@@ -254,10 +250,71 @@ def fake_allreduce(
         )
         diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024
 
+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for g, (name, param) in zip(grads, model.named_parameters()):
+            diloco._fragments[0]._grads[name] = g.clone()
+
         # Run only bucketized logic
         diloco._fragments[0]._average_grads()
 
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
         # Expect grads to have been doubled
         expected_grads = [g * 2 for g in grads]
         for param, expected in zip(model.parameters(), expected_grads):
             torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)
+
+    def test_gradient_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(
+            tensor: Tensor, should_quantize: bool
+        ) -> torch.futures.Future[Tensor]:
+            tensor.mul_(2)
+            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+            fut.set_result(tensor)
+            return fut
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(manager, [model], inner_opt, outer_opt, sync_every=2)
+
+        # save original parameters
+        diloco._fragments[0].save_parameters()
+
+        # change the model's parameters
+        for p in model.parameters():
+            p.data.add_(2)
+
+        # calculate and set the gradients
+        diloco._fragments[0]._save_grads()
+
+        # calculate
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # we added 2 to the parameters, then multiplied the gradients by 2
+        # so we should expect the model's gradient to be 4
+        expected_grad = 4
+        for param in model.parameters():
+            assert param.grad is not None
+            t = torch.empty_like(param.grad)
+            t.fill_(expected_grad)
+            torch.testing.assert_close(param.grad, t)
