@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import os
 import sys
 import unittest
@@ -29,6 +30,10 @@
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -84,8 +89,10 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
 
     dataset = TensorDataset(data, labels)
 
-    loss_fn = nn.MSELoss()
-    if dp and clipping == "flat":
+    reduction = "none" if dp and clipping == "ghost" else "mean"
+    loss_fn = nn.CrossEntropyLoss(reduction=reduction)
+
+    if dp and clipping in ["flat", "ghost"]:
         ddp_model = DPDDP(model)
     else:
         ddp_model = DDP(model, device_ids=[rank])
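The switch from `nn.MSELoss()` to `nn.CrossEntropyLoss(reduction="none")` in the ghost case is deliberate: fast gradient clipping consumes one loss value per sample rather than a pre-reduced scalar. A quick standalone shape check (plain PyTorch, not part of this diff):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction="none")
logits = torch.randn(4, 5)            # (batch_size, num_classes)
targets = torch.randint(0, 5, (4,))   # class indices
loss_per_sample = criterion(logits, targets)
print(loss_per_sample.shape)          # torch.Size([4]) -- one loss per sample, no reduction
```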
@@ -115,15 +122,24 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
                 optimizer,
                 (DistributedPerLayerOptimizer, SimpleDistributedPerLayerOptimizer),
             )
+        elif clipping == "ghost":
+            assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
         else:
             assert isinstance(optimizer, DistributedDPOptimizer)
 
     for x, y in data_loader:
-        outputs = ddp_model(x.to(rank))
-        loss = loss_fn(outputs, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if dp and clipping == "ghost":
+            ddp_model.enable_hooks()
+            outputs = ddp_model(x.to(rank))
+            loss_per_sample = loss_fn(outputs, y)
+            double_backward(ddp_model, optimizer, loss_per_sample)
+            optimizer.step()
+        else:
+            outputs = ddp_model(x.to(rank))
+            loss = loss_fn(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
         break
 
     weight.copy_(model.net1.weight.data.cpu())
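In the new `ghost` branch the optimizer only sees `optimizer.step()`; the gradient work is delegated to `double_backward`, which runs the two backward passes that fast gradient clipping requires (one to obtain per-sample gradient norms, one on the rescaled loss). The sketch below shows the same pattern in a minimal single-process form; the model, data, and `make_private` arguments are illustrative assumptions and are not taken from this test.

```python
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from opacus.utils.fast_gradient_clipping_utils import double_backward

# Illustrative setup (assumed, not from the test): a tiny classifier and dataset.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(8, 10), torch.randint(0, 5, (8,))),
    batch_size=4,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=0.0,
    max_grad_norm=1.0,
    poisson_sampling=False,
    grad_sample_mode="ghost",  # fast gradient clipping, as in the "ghost" test case
)

# Per-sample losses are required: double_backward rescales them by the
# per-sample clipping coefficients before the second backward pass.
criterion = nn.CrossEntropyLoss(reduction="none")

for x, y in data_loader:
    outputs = model(x)
    loss_per_sample = criterion(outputs, y)
    # No explicit zero_grad()/backward() here, mirroring the test's ghost branch.
    double_backward(model, optimizer, loss_per_sample)
    optimizer.step()
    break
```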
@@ -141,33 +157,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
 
 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with DP or without DDP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )
 
-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )
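For reference, the rewritten loop covers the four original `(clipping, grad_sample_mode)` combinations plus the new ghost-clipping case; expanding the parametrization in plain Python:

```python
import itertools

pairs = list(itertools.product(["flat", "per_layer"], ["hooks", "ew"]))
pairs.append(("ghost", "ghost"))
print(pairs)
# [('flat', 'hooks'), ('flat', 'ew'), ('per_layer', 'hooks'),
#  ('per_layer', 'ew'), ('ghost', 'ghost')]
```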