@@ -73,11 +73,11 @@ def setUp(self):
         torch.manual_seed(42)

     @abc.abstractmethod
-    def _init_data(self):
+    def _init_data(self) -> None:
         pass

     @abc.abstractmethod
-    def _init_model(self):
+    def _init_model(self) -> None:
         pass

     def _init_vanilla_training(
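The commit is one mechanical pattern applied throughout: test and setup methods that implicitly return `None` gain an explicit `-> None` annotation, so type checkers such as Pyre or mypy check the method bodies instead of skipping unannotated functions. A minimal illustration of the before/after shape (hypothetical example, not from this file):

```python
import unittest


class ExampleTest(unittest.TestCase):
    # Unannotated: many type checkers treat the body as untyped.
    def test_untyped(self):
        self.assertEqual(1 + 1, 2)

    # Annotated: the checker verifies the body and flags any
    # accidental `return value` in a method declared to return None.
    def test_typed(self) -> None:
        self.assertEqual(1 + 1, 2)
```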
@@ -193,7 +193,7 @@ def closure():
             if max_steps and steps >= max_steps:
                 break

-    def test_basic(self):
+    def test_basic(self) -> None:
         for opt_exclude_frozen in [True, False]:
             with self.subTest(opt_exclude_frozen=opt_exclude_frozen):
                 model, optimizer, dl, _ = self._init_private_training(
@@ -287,7 +287,7 @@ def test_compare_to_vanilla(
             max_steps=max_steps,
         )

-    def test_flat_clipping(self):
+    def test_flat_clipping(self) -> None:
         self.BATCH_SIZE = 1
         max_grad_norm = 0.5

@@ -314,7 +314,7 @@ def test_flat_clipping(self):
         self.assertAlmostEqual(clipped_grads.norm().item(), max_grad_norm, places=3)
         self.assertGreater(non_clipped_grads.norm(), clipped_grads.norm())

-    def test_per_layer_clipping(self):
+    def test_per_layer_clipping(self) -> None:
         self.BATCH_SIZE = 1
         max_grad_norm_per_layer = 1.0

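`test_flat_clipping` above asserts that, with a single global threshold, the clipped gradient's total norm lands on `max_grad_norm` while the unclipped norm stays larger. The arithmetic behind that assertion is the standard DP-SGD clipping rule; a rough sketch of flat (all-layers-at-once) clipping, not Opacus's internal implementation:

```python
import torch


def flat_clip(per_sample_grad: torch.Tensor, max_grad_norm: float) -> torch.Tensor:
    """Scale one flattened per-sample gradient so its L2 norm is <= max_grad_norm."""
    norm = per_sample_grad.norm()
    scale = min(1.0, max_grad_norm / (norm + 1e-6))
    return per_sample_grad * scale


g = torch.randn(1000)           # stand-in for a gradient flattened across all layers
clipped = flat_clip(g, max_grad_norm=0.5)
assert clipped.norm().item() <= 0.5 + 1e-4  # norm hits the threshold when g is large
```

Per-layer clipping (the next hunk's `test_per_layer_clipping`) applies the same rule to each layer's gradient separately, each with its own threshold.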
@@ -344,7 +344,7 @@ def test_per_layer_clipping(self):
             min(non_clipped_norm, max_grad_norm_per_layer), clipped_norm, places=3
         )

-    def test_sample_grad_aggregation(self):
+    def test_sample_grad_aggregation(self) -> None:
         """
         Check if final gradient is indeed an aggregation over per-sample gradients
         """
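The docstring states the invariant under test: with clipping effectively disabled, the aggregated gradient `p.grad` should equal the sum of the per-sample gradients. A self-contained illustration of that identity using plain autograd (assuming a sum-reduced loss; no Opacus machinery):

```python
import torch

torch.manual_seed(42)
w = torch.randn(3, requires_grad=True)
xs = torch.randn(4, 3)  # a batch of 4 samples

# Batch gradient: one backward pass over the summed loss.
(xs @ w).pow(2).sum().backward()
batch_grad = w.grad.clone()

# Per-sample gradients: one backward pass per sample.
per_sample = []
for x in xs:
    w.grad = None
    (x @ w).pow(2).backward()
    per_sample.append(w.grad.clone())

# The batch gradient is the aggregation of the per-sample gradients.
assert torch.allclose(batch_grad, torch.stack(per_sample).sum(dim=0), atol=1e-6)
```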
@@ -367,7 +367,7 @@ def test_sample_grad_aggregation(self):
                 f"Param: {p_name}",
             )

-    def test_noise_changes_every_time(self):
+    def test_noise_changes_every_time(self) -> None:
         """
         Test that adding noise results in ever different model params.
         We disable clipping in this test by setting it to a very high threshold.
@@ -387,7 +387,7 @@ def test_noise_changes_every_time(self):
         for p0, p1 in zip(first_run_params, second_run_params):
             self.assertFalse(torch.allclose(p0, p1))

-    def test_get_compatible_module_inaction(self):
+    def test_get_compatible_module_inaction(self) -> None:
         needs_no_replacement_module = nn.Linear(1, 2)
         fixed_module = PrivacyEngine.get_compatible_module(needs_no_replacement_module)
         self.assertFalse(fixed_module is needs_no_replacement_module)
@@ -397,7 +397,7 @@ def test_get_compatible_module_inaction(self):
             )
         )

-    def test_model_validator(self):
+    def test_model_validator(self) -> None:
         """
         Test that the privacy engine raises errors
         if there are unsupported modules
@@ -416,7 +416,7 @@ def test_model_validator(self):
             grad_sample_mode=self.GRAD_SAMPLE_MODE,
         )

-    def test_model_validator_after_fix(self):
+    def test_model_validator_after_fix(self) -> None:
         """
         Test that the privacy engine fixes unsupported modules
         and succeeds.
@@ -435,7 +435,7 @@ def test_model_validator_after_fix(self):
         )
         self.assertTrue(1, 1)

-    def test_make_private_with_epsilon(self):
+    def test_make_private_with_epsilon(self) -> None:
         model, optimizer, dl = self._init_vanilla_training()
         target_eps = 2.0
         target_delta = 1e-5
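`make_private_with_epsilon` is the `make_private` variant that solves for the noise multiplier needed to land on a target (ε, δ) over a fixed training length; the test then verifies `get_epsilon(target_delta)` comes out near the target. A hedged usage sketch (argument names as in recent Opacus releases; model and data are toy stand-ins):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    epochs=10,
    target_epsilon=2.0,   # mirrors target_eps in the test
    target_delta=1e-5,    # mirrors target_delta in the test
    max_grad_norm=1.0,
)
# After training for the declared number of epochs,
# privacy_engine.get_epsilon(1e-5) should be close to 2.0.
```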
@@ -458,7 +458,7 @@ def test_make_private_with_epsilon(self):
             target_eps, privacy_engine.get_epsilon(target_delta), places=2
         )

-    def test_deterministic_run(self):
+    def test_deterministic_run(self) -> None:
         """
         Tests that for 2 different models, secure seed can be fixed
         to produce same (deterministic) runs.
@@ -483,7 +483,7 @@ def test_deterministic_run(self):
             "Model parameters after deterministic run must match",
         )

-    def test_validator_weight_update_check(self):
+    def test_validator_weight_update_check(self) -> None:
         """
         Test that the privacy engine raises error if ModuleValidator.fix(model) is
         called after the optimizer is created
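The ordering constraint behind this test: `ModuleValidator.fix` may replace unsupported modules with fresh ones, so an optimizer created beforehand would keep references to the old, now-detached parameters; Opacus raises rather than training silently against dead tensors. A sketch of the safe order, assuming the standard `opacus.validators` API:

```python
import torch
from torch import nn
from opacus.validators import ModuleValidator

# BatchNorm is not DP-friendly; fix() swaps it out (e.g. for GroupNorm),
# which creates new parameter tensors.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model = ModuleValidator.fix(model)

# Build the optimizer only AFTER fixing, so it tracks the post-fix parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```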
@@ -522,7 +522,7 @@ def test_validator_weight_update_check(self):
             grad_sample_mode=self.GRAD_SAMPLE_MODE,
         )

-    def test_parameters_match(self):
+    def test_parameters_match(self) -> None:
         dl = self._init_data()

         m1 = self._init_model()
@@ -721,7 +721,7 @@ def helper_test_noise_level(

     @unittest.skip("requires torchcsprng compatible with new pytorch versions")
     @patch("torch.normal", MagicMock(return_value=torch.Tensor([0.6])))
-    def test_generate_noise_in_secure_mode(self):
+    def test_generate_noise_in_secure_mode(self) -> None:
         """
         Tests that the noise is added correctly in secure_mode,
         according to section 5.1 in https://arxiv.org/abs/2107.10138.
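The mitigation from section 5.1 of that paper avoids producing the noise in a single floating-point Gaussian draw; the variance bookkeeping rests on a simple identity: a sum of k i.i.d. N(0, σ²) samples, rescaled by 1/√k, is again N(0, σ²). A numeric illustration of that identity only, not of Opacus's secure-mode code path:

```python
import torch

torch.manual_seed(0)
sigma, k, n = 1.5, 4, 200_000

# Var(sum of k draws) = k * sigma^2, so dividing by sqrt(k)
# restores the target variance sigma^2.
noise = sum(torch.normal(0.0, sigma, (n,)) for _ in range(k)) / k ** 0.5

print(noise.std())  # ≈ 1.5
```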
@@ -803,16 +803,16 @@ def _init_model(self):


 class PrivacyEngineConvNetEmptyBatchTest(PrivacyEngineConvNetTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()

         # This will trigger multiple empty batches with poisson sampling enabled
         self.BATCH_SIZE = 1

-    def test_checkpoints(self):
+    def test_checkpoints(self) -> None:
         pass

-    def test_noise_level(self):
+    def test_noise_level(self) -> None:
         pass

@@ -837,23 +837,23 @@ def _init_model(self):


 class PrivacyEngineConvNetFrozenTestFunctorch(PrivacyEngineConvNetFrozenTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"


 class PrivacyEngineConvNetTestExpandedWeights(PrivacyEngineConvNetTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "ew"

     @unittest.skip("Original p.grad is not available in ExpandedWeights")
-    def test_sample_grad_aggregation(self):
+    def test_sample_grad_aggregation(self) -> None:
         pass


 class PrivacyEngineConvNetTestFunctorch(PrivacyEngineConvNetTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"

@@ -938,7 +938,7 @@ def _init_model(


 class PrivacyEngineTextTestFunctorch(PrivacyEngineTextTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"

@@ -987,7 +987,7 @@ def _init_model(self):


 class PrivacyEngineTiedWeightsTestFunctorch(PrivacyEngineTiedWeightsTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"
