
Commit ddf4d43

iden-kalemaj authored and facebook-github-bot committed
Add multi_gpu test for ghost clipping (#665)
Summary: Pull Request resolved: #665

Modify the existing `multigpu_gradcheck.py` test to check gradient correctness for ghost clipping in a distributed setting.

Differential Revision: D60840755
1 parent f2a591a commit ddf4d43

File tree: 1 file changed (+53, -32 lines)


Diff for: opacus/tests/multigpu_gradcheck.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import os
 import sys
 import unittest
@@ -29,6 +30,10 @@
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -84,8 +89,10 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
 
     dataset = TensorDataset(data, labels)
 
-    loss_fn = nn.MSELoss()
-    if dp and clipping == "flat":
+    reduction = "none" if dp and clipping == "ghost" else "mean"
+    loss_fn = nn.CrossEntropyLoss(reduction=reduction)
+
+    if dp and clipping in ["flat", "ghost"]:
         ddp_model = DPDDP(model)
     else:
         ddp_model = DDP(model, device_ids=[rank])
@@ -115,15 +122,24 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
             optimizer,
             (DistributedPerLayerOptimizer, SimpleDistributedPerLayerOptimizer),
         )
+    elif clipping == "ghost":
+        assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
     else:
         assert isinstance(optimizer, DistributedDPOptimizer)
 
     for x, y in data_loader:
-        outputs = ddp_model(x.to(rank))
-        loss = loss_fn(outputs, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if dp and clipping == "ghost":
+            ddp_model.enable_hooks()
+            outputs = ddp_model(x.to(rank))
+            loss_per_sample = loss_fn(outputs, y)
+            double_backward(ddp_model, optimizer, loss_per_sample)
+            optimizer.step()
+        else:
+            outputs = ddp_model(x.to(rank))
+            loss = loss_fn(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
         break
 
     weight.copy_(model.net1.weight.data.cpu())
@@ -141,33 +157,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
 
 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with DP or without DDP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )
 
-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )
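
For readers unfamiliar with the ghost clipping path this commit starts testing, the sketch below shows the per-batch pattern the new `clipping == "ghost"` branch exercises: an un-reduced per-sample loss (`reduction="none"`), a `double_backward` call in place of a single `loss.backward()`, and then `optimizer.step()`. Only that pattern and the `double_backward(module, optimizer, loss_per_sample)` call are taken from the diff; the single-process setup around it, including the `GradSampleModuleFastGradientClipping` and `DPOptimizerFastGradientClipping` wrappers and their constructor keywords, is an assumption about the Opacus fast-gradient-clipping API and is not part of this change.

# Minimal, non-distributed sketch of the ghost clipping step tested above.
# The wrapper classes and constructor keywords are assumptions; only the
# per-sample loss -> double_backward -> optimizer.step pattern comes from the test.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from opacus.grad_sample import GradSampleModuleFastGradientClipping  # assumed import path
from opacus.optimizers import DPOptimizerFastGradientClipping  # assumed import path
from opacus.utils.fast_gradient_clipping_utils import double_backward

model = nn.Linear(10, 10)
max_grad_norm = 1.0

# Wrap the model so the ghost clipping hooks can compute per-sample gradient
# norms without materializing full per-sample gradients.
dp_model = GradSampleModuleFastGradientClipping(
    model, max_grad_norm=max_grad_norm, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(model.parameters(), lr=0.1),
    noise_multiplier=0.0,  # noise disabled, as in a gradient-correctness check
    max_grad_norm=max_grad_norm,
    expected_batch_size=4,
)

# Ghost clipping needs the un-reduced, per-sample loss.
loss_fn = nn.CrossEntropyLoss(reduction="none")
data_loader = DataLoader(
    TensorDataset(torch.randn(8, 10), torch.randint(0, 10, (8,))), batch_size=4
)

for x, y in data_loader:
    outputs = dp_model(x)
    loss_per_sample = loss_fn(outputs, y)
    # double_backward runs the two backward passes ghost clipping requires:
    # one to obtain per-sample gradient norms, one to accumulate clipped gradients.
    double_backward(dp_model, optimizer, loss_per_sample)
    optimizer.step()
    break

In the test itself, the same step runs under DPDDP across two GPUs, and the resulting `net1` weights are compared against a plain DDP run with `torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)`.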
