generated from ashleve/lightning-hydra-template
Add callback that freezes specified module #141
Open

dxoigmn wants to merge 232 commits into main from freeze_callback
Changes from 216 commits

Commits (232)
- f36d8c1 (dxoigmn) Remove NoAdversary
- dcf7114 (dxoigmn) Remove NoAdversary from CIFAR10 adversarial training
- 75886b6 (dxoigmn) Remove NoAdversary from RetinaNet model
- 52aa94d (dxoigmn) Fix COCO_TorchvisionFasterRCNN_Adv experiment
- bbed2d6 (dxoigmn) Remove NoAdversary tests
- f3f7b1b (dxoigmn) First stab at treating adversary as LightningModule
- 0e16460 (dxoigmn) style
- e25af6f (dxoigmn) bugfix
- 36295b5 (dxoigmn) Integrate Perturber into LitPerturber
- c68d9a8 (dxoigmn) Integrate objective to LitPerturber
- 689da74 (dxoigmn) Cleanup use of objective function to compute gain
- 3d70dde (dxoigmn) bugfix
- 2990cab (dxoigmn) Make adversarial trainer silent
- d5e17f7 (dxoigmn) style
- b8f8761 (dxoigmn) Move threat model into LitPerturber
- d39c5c1 (dxoigmn) Make attack callbacks plain PL callbacks
- 5b7ee68 (dxoigmn) comment
- f980315 (dxoigmn) Remove Perturber
- 90fb9ab (dxoigmn) comment
- 00aefee (dxoigmn) Better silence
- 86bbc13 (dxoigmn) Integrate LitPerturber into Adversary
- 7a12bf8 (dxoigmn) Uncombine Adversary into Adversary and LitPerturber
- ad1872a (dxoigmn) cleanup
- 1aa4fc0 (dxoigmn) Move silence into utils
- d09c5b8 (dxoigmn) bugfix
- 78701a4 (dxoigmn) bugfix
- 8143b5d (dxoigmn) cleanup
- 7bfc98f (dxoigmn) Enable dependency injection on Adversary
- eaf1607 (dxoigmn) Make dependency injection backwards compatible
- bf5df50 (dxoigmn) Replace max_iters with trainer.limit_train_batches
- 75fa073 (dxoigmn) comments
- 4f74e53 (dxoigmn) Move perturbation creation into initializer
- cde433c (dxoigmn) Add Default projector
- fdcfb5c (dxoigmn) bugfix
- e01093f (dxoigmn) comment
- 63d3e40 (dxoigmn) Move gradient modifier into PL hook
- e71266b (dxoigmn) Use on_train_epoch_start in favor of initialize_parameters
- 7264b31 (dxoigmn) Make perturbation lazy
- 4982851 (dxoigmn) Disable logger in attack
- c5e5ddf (dxoigmn) Revert initializer to d33658fac734274bbf87bce88a8b470afa1b3c71
- 41357cf (dxoigmn) cleanup
- b5d116b (dxoigmn) on_before_optimizer_step -> configure_gradient_clipping
- 79dabd4 (dxoigmn) comments
- a6f1d84 (dxoigmn) Disable attack progress bar
- 2b9f403 (dxoigmn) comments
- 0bd7c7c (dxoigmn) comments
- bac2bd4 (dxoigmn) comments
- 079c15c (dxoigmn) cleanup
- 1a57cb6 (dxoigmn) comments
- 7b96590 (dxoigmn) cleanup
- 1357297 (dxoigmn) comments
- 3e0d27c (dxoigmn) comment
- e59767e (dxoigmn) Move LitPerturber into perturber.py
- 0ef0a6c (dxoigmn) bugfix
- ab623e9 (dxoigmn) Make gradient modifiers in-place operations
- b313209 (dxoigmn) cleanup
- 3ef0876 (dxoigmn) Mark initializers __call__ as no_grad instead of using .data
- 0c63042 (dxoigmn) Mark projectors __call__ as no_grad instead of using .data
- 65397a4 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- faefbee (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- 6c2bbdc (dxoigmn) Cleanup attack configs
- 76da9b7 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- 8890ffd (dxoigmn) Fix merge error
- 19cf58d (dxoigmn) Fix merge error
- 0b438df (dxoigmn) comment
- 3da4aa2 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- ef84140 (mzweilin) Merge branch 'main' into adversary_as_lightningmodule
- a01332a (mzweilin) Make Enforcer accept **kwargs.
- 333bf61 (mzweilin) Update test_gradient.
- 8c84b43 (dxoigmn) LitPerturber -> Perturber
- 0f866b9 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- 970f53b (dxoigmn) cleanup
- addb2ab (dxoigmn) Add _reset functionality
- 38c8caf (dxoigmn) Update tests and fix a bug
- 020f99b (dxoigmn) Remove batch tests
- 7332f74 (dxoigmn) style
- 260fd3d (dxoigmn) Late bind trainer to input device
- cd47df6 (dxoigmn) fix visualizer test
- 0519b5a (dxoigmn) bugfix
- 4024852 (dxoigmn) bugfix
- 1980d99 (dxoigmn) disable progress bar
- 9809362 (dxoigmn) bugfix
- bad9c10 (dxoigmn) Add loss to object detection outputs
- c1cb07f (dxoigmn) comment
- b6afdd1 (dxoigmn) Make Adversary and Perturber tuple-aware
- 8e7bc21 (dxoigmn) comment
- c6bd29f (dxoigmn) Update tests and fix bug
- 739e4db (dxoigmn) style
- e581452 (dxoigmn) Remove BatchEnforcer and BatchComposer
- 911998f (dxoigmn) Revert to old gain functionality
- 77f9828 (dxoigmn) Revert change to enforcer
- d435d1e (dxoigmn) fix perturber tests to take gain
- aa08118 (dxoigmn) cleanup
- 74f8fab (dxoigmn) Place Trainer on same device as Perturber
- f1c2a9d (dxoigmn) Make composer, enforcer and projector tuple aware
- c5020b5 (dxoigmn) fix projector tests
- 8773d7f (dxoigmn) Gracefully fail when input is a dict
- d88c458 (dxoigmn) Make Projector batch aware
- 13b8e32 (dxoigmn) Update projector tests
- f0367c5 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- 91f8904 (dxoigmn) Merge branch 'main' into make_projector_batch_aware
- 9b545f4 (dxoigmn) Merge branch 'make_projector_batch_aware' into adversary_as_lightning…
- eca9220 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- c4d9f79 (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- 3d9e36a (dxoigmn) Fix Adversary gradient test
- 3f66c50 (dxoigmn) Make attacker a property
- a5368a3 (dxoigmn) Fix configuration to construct Perturber
- d0ae325 (dxoigmn) Remove MaskAdditive
- c77de20 (dxoigmn) Revert "Remove MaskAdditive"
- 98402d5 (dxoigmn) cleanup
- a087bec (dxoigmn) Undelete
- 5122366 (dxoigmn) Merge Perturber into Adversary
- b26ba8e (dxoigmn) Update projector test to use proper spec
- 1030e46 (dxoigmn) Abstract Perturber again
- 13d50ca (dxoigmn) bugfix
- 9871142 (dxoigmn) Smarter configure_perturbation
- 5207df7 (dxoigmn) Make perturbations proper parameters and cleanup Initializer
- 427bcc5 (dxoigmn) Cleanup GradientModifier
- 9b3cae7 (dxoigmn) Cleanup Composer
- 25e839b (dxoigmn) Cleanup Projector
- 41924ed (dxoigmn) Cleanup Enforcer
- 0bb23b5 (dxoigmn) Cleanup callbacks
- 9d75293 (dxoigmn) Cleanup NormalizedAdversaryAdapter
- 429d1fb (dxoigmn) Cleanup MartToArtAttackAdapter
- 28aa453 (dxoigmn) Smarter detection of when we need to create perturbation
- a169806 (dxoigmn) cleanup
- 9e5ae70 (dxoigmn) style
- 2ff5d0a (dxoigmn) Remove GradientModifier from Perturber and cleanup
- fec53c2 (dxoigmn) bugfix
- d6605a1 (dxoigmn) Add GradientModifier to LitModular
- b5d5442 (dxoigmn) Adversary consumes a OptimizerFactory
- d2b4483 (dxoigmn) Better Composer and Projector type logic
- 857d8f6 (dxoigmn) spelling
- 9420a00 (dxoigmn) bugfix
- 56014e6 (dxoigmn) Replace tuple with Iterable[torch.Tensor]
- 280609a (dxoigmn) Merge branch 'iterable_instead_of_tuple' into adversary_as_lightningm…
- ef0cea1 (dxoigmn) cleanup
- 1c47cc0 (dxoigmn) Fix tests
- a009fd7 (dxoigmn) Merge branch 'iterable_instead_of_tuple' into adversary_as_lightningm…
- 70cc36a (dxoigmn) Cleanup
- 8e632e0 (dxoigmn) Merge branch 'iterable_instead_of_tuple' into adversary_as_lightningm…
- 53ee7f4 (dxoigmn) Make GradientModifier accept Iterable[torch.Tensor]
- 89a6ce2 (dxoigmn) Merge branch 'iterable_instead_of_tuple' into adversary_as_lightningm…
- 1e9526a (dxoigmn) Revert changes to LitModular
- f7345da (dxoigmn) Revert Adversary consumes a OptimizerFactory
- 5a49f82 (dxoigmn) Remove Callback base
- 1fbae00 (dxoigmn) Fix tests
- 8493da1 (dxoigmn) bugfix
- 3068bc3 (dxoigmn) bugfix
- 841f813 (dxoigmn) style
- 14a362b (dxoigmn) Add callback that freezes specified module
- 8af2118 (dxoigmn) Remove AttackInEvalMode in favor of FreezeCallback
- 557bcd2 (dxoigmn) bugfix
- fff368d (dxoigmn) Make attack callbacks normal callbacks
- afc9f75 (dxoigmn) Merge branch 'cleanup_callbacks' into freeze_callback
- 715a73d (dxoigmn) Remove config
- 557d878 (dxoigmn) Remove NoGradMode callback
- 10106df (dxoigmn) Fix annotations
- 5cd900e (dxoigmn) Make Perturber more flexible
- 1953938 (dxoigmn) bugfix
- c6dc5a4 (dxoigmn) Add GradientModifier and fix tests
- 0055b10 (dxoigmn) Fix configs
- e492e70 (dxoigmn) Get adversary tests from adversary_as_lightningmodule
- c9c8429 (dxoigmn) Merge branch 'better_perturber' into adversary_as_lightningmodule
- e8dadcb (dxoigmn) Make attack callbacks normal callbacks
- 27a3f80 (dxoigmn) Move attack optimizers to optimizers
- b60057e (dxoigmn) Merge branch 'better_optimizer' into adversary_as_lightningmodule
- c48f410 (dxoigmn) bugfix
- 754570d (dxoigmn) style
- b7c14b1 (dxoigmn) comment
- 559bcfe (dxoigmn) Merge branch 'main' into better_perturber
- 6fd4943 (dxoigmn) fix test
- c71cba6 (dxoigmn) style
- 43f4520 (dxoigmn) Merge branch 'better_perturber' into better_optimizer
- 58f91dc (dxoigmn) style
- bc03a87 (dxoigmn) Perturber is no longer a callback
- b9af839 (dxoigmn) fix tests
- a8d7201 (dxoigmn) fix tests
- 7826e39 (dxoigmn) fix tests
- fa3545b (dxoigmn) fix tests
- 57edc03 (dxoigmn) fix tests
- 83938a9 (dxoigmn) fix tests
- 0556e35 (dxoigmn) Merge branch 'better_perturber' into better_optimizer
- f5ee114 (dxoigmn) fix tests
- a631fa1 (dxoigmn) bugfix
- 2115597 (dxoigmn) bugfix
- 1679feb (dxoigmn) Merge branch 'better_optimizer' into adversary_as_lightningmodule
- 4edbdfa (dxoigmn) fix tests
- ef15c53 (dxoigmn) add missing tests
- dcf7599 (dxoigmn) return tests to original tests
- 46ed57f (dxoigmn) style
- 41eb387 (dxoigmn) Set optimizer to maximize in attacks
- 5932223 (dxoigmn) Revert "Set optimizer to maximize in attacks"
- 3bf7353 (dxoigmn) Adversary optimizer maximizes gain
- bd438b5 (dxoigmn) Merge branch 'better_optimizer' into adversary_as_lightningmodule
- 1489e7a (dxoigmn) Merge branch 'adversary_as_lightningmodule' into freeze_callback
- bee5221 (dxoigmn) comments
- ea22015 (dxoigmn) remove yaml files
- 08b3d00 (dxoigmn) Only set eval mode for BatchNorm and Dropout modules
- a1f301f (dxoigmn) Move Composer from Perturber and into Attacker
- f6b367b (dxoigmn) cleanup
- 6a7673c (dxoigmn) bugfix
- c998de7 (dxoigmn) Merge branch 'better_perturber' into better_optimizer
- 2d4366e (dxoigmn) Merge branch 'better_optimizer' into adversary_as_lightningmodule
- 98fdcc7 (dxoigmn) Add Composer to Adversary
- 2db13a0 (dxoigmn) cleanup
- c2f1f77 (dxoigmn) projector -> projector_
- 48520cf (mzweilin) Hide adversarial parameters from model checkpoint. (#150)
- 8a1ca12 (mzweilin) Merge branch 'better_perturber' into better_optimizer
- 3ff5796 (mzweilin) Merge branch 'better_perturber' into adversary_as_lightningmodule
- 1c0eb27 (mzweilin) Merge branch 'better_optimizer' into adversary_as_lightningmodule
- 773b2cb (dxoigmn) Merge branch 'main' into adversary_as_lightningmodule
- 1293a89 (dxoigmn) Fix merge error
- e8d0852 (dxoigmn) Merge branch 'adversary_as_lightningmodule' into freeze_callback
- 53dff65 (dxoigmn) Merge branch 'main' into freeze_callback
- 8c55b47 (dxoigmn) Merge branch 'main' into freeze_callback
- 99a7669 (dxoigmn) Use attrgetter
- 588068c (dxoigmn) Better implementation of ModelParamsNoGrad
- 3832d22 (dxoigmn) Better implementation of AttackInEvalMode
- a9348df (dxoigmn) Log which params will have gradients disabled
- 9c955df (dxoigmn) Remove Freeze callback
- d278aba (dxoigmn) bugfix
- 55a6161 (dxoigmn) comments
- 830e765 (dxoigmn) comments
- be8ae5d (dxoigmn) comments
- 04069b9 (dxoigmn) Even better AttackInEvalMode
- 113d483 (dxoigmn) Fix type
- 3dbdfd4 (dxoigmn) Even better ModelParamsNoGrad
- 48577ad (dxoigmn) more lenient
- 3d04ef7 (dxoigmn) Update example modules to run in eval mode
- 77c2350 (dxoigmn) Only log and run in fit stage
- cc8a5d7 (dxoigmn) Merge branch 'main' into freeze_callback
```diff
@@ -1,5 +1,4 @@
-from .eval_mode import *
+from .freeze import *
 from .gradients import *
-from .no_grad_mode import *
 from .progress_bar import *
 from .visualizer import *
```
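The diff above changes which submodules the callbacks package re-exports. Because the new freeze module declares `__all__ = ["FreezeModule"]`, the `from .freeze import *` line re-exports exactly that one name. A stdlib-only sketch of that mechanism (module and class names here are stand-ins, not MART code):

```python
import sys
import types

# Build a stand-in module that, like freeze.py, restricts its wildcard
# exports via __all__. `Helper` is defined but deliberately not listed.
mod = types.ModuleType("freeze_demo")
exec(
    "__all__ = ['FreezeModule']\n"
    "class FreezeModule: pass\n"
    "class Helper: pass\n",
    mod.__dict__,
)
sys.modules["freeze_demo"] = mod

# A wildcard import only pulls in names listed in __all__.
ns = {}
exec("from freeze_demo import *", ns)
print("FreezeModule" in ns)  # True
print("Helper" in ns)        # False
```

This is why consumers can rely on `from mart.callbacks import FreezeModule` without the package leaking the module's private helpers.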
This file was deleted.
New file (57 lines):

```python
#
# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: BSD-3-Clause
#

from __future__ import annotations

import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from mart import utils

logger = utils.get_pylogger(__name__)

__all__ = ["FreezeModule"]


class FreezeModule(Callback):
    def __init__(
        self,
        module="backbone",
    ):
        self.name = module

    def setup(self, trainer, pl_module, stage):
        # FIXME: Use DotDict?
        module = getattr(pl_module.model, self.name, None)

        if module is None or not isinstance(module, torch.nn.Module):
            raise MisconfigurationException(
                f"The LightningModule should have a nn.Module `{self.name}` attribute"
            )

        for name, param in module.named_parameters():
            logger.debug(f"Disabling gradient for {name}")
            param.requires_grad_(False)

        for name, module in module.named_modules():
            module_kind = module.__class__.__name__
            if "BatchNorm" in module_kind:
                logger.info(f"Setting eval mode for {name} ({module_kind})")

    def on_train_epoch_start(self, trainer, pl_module):
        # FIXME: Use DotDict?
        module = getattr(pl_module.model, self.name, None)

        if module is None or not isinstance(module, torch.nn.Module):
            raise MisconfigurationException(
                f"The LightningModule should have a nn.Module `{self.name}` attribute"
            )

        for name, module in module.named_modules():
            module_kind = module.__class__.__name__
            if "BatchNorm" in module_kind or "Dropout" in module_kind:
                module.eval()
```
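To illustrate what the callback does at runtime, here is a minimal, self-contained sketch using plain torch. The model and its `backbone`/`head` submodules are made up for illustration (`backbone` matches the callback's default `module="backbone"`); in MART the callback is attached to a Lightning Trainer rather than applied by hand like this:

```python
import torch

# Hypothetical host model with a `backbone` submodule to freeze.
model = torch.nn.Module()
model.backbone = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.BatchNorm1d(8),
    torch.nn.Dropout(0.5),
)
model.head = torch.nn.Linear(8, 2)

# What setup() does once: disable gradients on every backbone parameter.
for name, param in model.backbone.named_parameters():
    param.requires_grad_(False)

# What on_train_epoch_start() re-applies: put BatchNorm/Dropout children in
# eval mode so running statistics and dropout masks stay fixed. Repeating
# this each epoch matters because Lightning puts the whole model back into
# train mode for training.
for name, child in model.backbone.named_modules():
    kind = child.__class__.__name__
    if "BatchNorm" in kind or "Dropout" in kind:
        child.eval()
```

After this, the backbone's parameters have `requires_grad=False` while the head remains trainable, and the BatchNorm/Dropout layers inside the backbone report `training=False` even though the rest of the model is in train mode.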
This file was deleted.

This file was deleted.
New file (3 lines):

```yaml
freeze:
  _target_: mart.callbacks.FreezeModule
  module: ???
```
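In this Hydra config group, `_target_` names the class to instantiate and `???` is OmegaConf's mandatory-value marker, so composition fails unless an experiment supplies the submodule name. A hypothetical composed result (the `backbone` value below is illustrative, not part of this PR) might look like:

```yaml
# Experiment config after overriding the mandatory value, e.g. with
# `callbacks.freeze.module=backbone` on the command line.
freeze:
  _target_: mart.callbacks.FreezeModule
  module: backbone
```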
This file was deleted.