Skip to content

Commit 40da9f1

Browse files
HuanyuZhang and facebook-github-bot
authored and committed
Reduce module size or the number of steps to avoid over-time tests (#739)
Summary: As titled. The Facebook buck test runner has a limit of 10 minutes per test. To avoid overtime failures, we slightly reduce the number of parameters and the number of repetitions. This won't reduce the credibility of the tests. Differential Revision: D70707205
1 parent 86b4ab4 commit 40da9f1

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

Diff for: opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __init__(self):
5454
super(SampleModule, self).__init__()
5555
self.fc1 = nn.Linear(2, 2)
5656
self.fc3 = nn.Linear(2, 1024)
57-
self.fc4 = nn.Linear(1024, 1024)
58-
self.fc5 = nn.Linear(1024, 1)
57+
self.fc4 = nn.Linear(1024, 10)
58+
self.fc5 = nn.Linear(10, 1)
5959
self.layer_norm = nn.LayerNorm(2)
6060

6161
def forward(self, x):
@@ -119,7 +119,7 @@ def setUp_data_sequantial(self, size, length, dim):
119119

120120
@given(
121121
size=st.sampled_from([10]),
122-
length=st.sampled_from([1]),
122+
length=st.sampled_from([5]),
123123
dim=st.sampled_from([2]),
124124
)
125125
@settings(deadline=1000000)
@@ -195,12 +195,12 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
195195
diff = flat_norms_normal - flat_norms_gc
196196

197197
logging.info(f"Diff = {diff}"),
198-
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
198+
msg = "Fail: Per-sample gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
199199
assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
200200

201201
@given(
202202
size=st.sampled_from([10]),
203-
length=st.sampled_from([1, 5]),
203+
length=st.sampled_from([5]),
204204
dim=st.sampled_from([2]),
205205
)
206206
@settings(deadline=1000000)

Diff for: opacus/tests/privacy_engine_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def _compare_to_vanilla(
268268
do_clip=st.booleans(),
269269
do_noise=st.booleans(),
270270
use_closure=st.booleans(),
271-
max_steps=st.sampled_from([1, 4]),
271+
max_steps=st.sampled_from([1, 3]),
272272
)
273273
@settings(suppress_health_check=list(HealthCheck), deadline=None)
274274
def test_compare_to_vanilla(
@@ -660,7 +660,7 @@ def test_checkpoints(
660660

661661
@given(
662662
noise_multiplier=st.floats(0.5, 5.0),
663-
max_steps=st.integers(8, 10),
663+
max_steps=st.integers(3, 5),
664664
secure_mode=st.just(False), # TODO: enable after fixing torchcsprng build
665665
)
666666
@settings(suppress_health_check=list(HealthCheck), deadline=None)

0 commit comments

Comments (0)