diff --git a/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml
index 63d79201..1cd747cb 100644
--- a/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml
+++ b/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml
@@ -2,6 +2,8 @@ epochs: 300
 output_dir: output_dir
 seed: 0
 
+accumulate_grad_steps: 1
+
 model:
   name: SwinWrapper
   architecture:
diff --git a/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml
index b224dc0c..6aa02f98 100644
--- a/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml
+++ b/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml
@@ -14,15 +14,11 @@ AMP:
                         "c_softmax_with_cross_entropy", "elementwise_div"]
     level: 'O1'
 
-hybrid:
-  dp_degree: 8
-  mp_degree: 1
-  pp_degree: 1
-
 sharding:
   sharding_stage: 2 # 2 or 'dp'
   offload: False
-  accumulate_grad: False
+
+accumulate_grad_steps: 1
 
 model:
   name: SwinWrapper
diff --git a/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml
index d424db59..ff7932bf 100644
--- a/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml
+++ b/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml
@@ -14,16 +14,11 @@ AMP:
     custom_black_list: ["reduce_mean", "reduce_sum",
                         "c_softmax_with_cross_entropy", "elementwise_div"]
     level: 'O1'
-
-hybrid:
-  dp_degree: 8
-  mp_degree: 1
-  pp_degree: 1
-
 sharding:
   sharding_stage: 2 # 2 or 'dp'
   offload: False
-  accumulate_grad: False
+
+accumulate_grad_steps: 2
 
 model:
   name: SwinWrapper
diff --git a/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml
index da9cd628..0a91d41e 100644
--- a/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml
+++ b/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml
@@ -1,7 +1,7 @@
 epochs: 300
 output_dir: output_dir
 seed: 0
-
+accumulate_grad_steps: 1
 model:
   name: SwinWrapper
   architecture:
diff --git a/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml
index 47f902db..27e74f30 100644
--- a/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml
+++ b/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml
@@ -2,6 +2,8 @@ epochs: 300
 output_dir: output_dir
 seed: 0
 
+accumulate_grad_steps: 1
+
 model:
   name: SwinWrapper
   architecture:
diff --git a/passl/engine/trainer.py b/passl/engine/trainer.py
index ac1e0022..78520f59 100644
--- a/passl/engine/trainer.py
+++ b/passl/engine/trainer.py
@@ -113,6 +113,8 @@ def __init__(self, cfg):
         use_simclr_iters = cfg.get('use_simclr_iters', False)
         self.use_simclr_iters = use_simclr_iters
         self.epochs = cfg.get('epochs', None)
+        self.accumulate_grad_steps = cfg.get('accumulate_grad_steps', 1)
+        self.accumulate_grads = True if self.accumulate_grad_steps > 1 else False
         self.timestamp = cfg.timestamp
         self.logs = OrderedDict()
         # Ensure that the vdl log file can be closed normally
@@ -147,7 +149,7 @@ def __init__(self, cfg):
         # distributed settings
         if dist.get_world_size() > 1:
             strategy = fleet.DistributedStrategy()
-            ## Hybrid Parallel Training
+            # Hybrid Parallel Training
             strategy.hybrid_configs = cfg.pop('hybrid') if 'hybrid' in cfg else {}
             fleet.init(is_collective=True, strategy=strategy)
             hcg = fleet.get_hybrid_communicate_group()
@@ -157,7 +159,7 @@ def __init__(self, cfg):
             set_hyrbid_parallel_seed(seed, 0, mp_rank, pp_rank)
 
         # amp training
-        self.use_amp = cfg.get('use_amp', False) #if 'use_amp' in cfg else False
+        self.use_amp = cfg.get('use_amp', False)
         if self.use_amp:
             amp_cfg = cfg.pop('AMP')
             self.auto_cast = amp_cfg.pop('auto_cast')
@@ -170,22 +172,24 @@ def __init__(self, cfg):
         self.sharding_strategies = cfg.get('sharding', False)
         if self.sharding_strategies:
             self.sharding_stage = self.sharding_strategies['sharding_stage']
-            accumulate_grad = self.sharding_strategies['accumulate_grad']
             offload = self.sharding_strategies['offload']
 
+            # Note: Only support partition optimizer stages and gradient now!
             if self.sharding_stage == 2:
+                # Partition Optimizer
                 self.optimizer = ShardingOptimizerStage2(
                     params=self.model.parameters(),
                     optim=self.optimizer,
                     offload=offload)
+                # Partition Gradients
                 self.model = ShardingStage2(
                     self.model,
                     self.optimizer,
-                    accumulate_grads=accumulate_grad)
+                    accumulate_grads=self.accumulate_grads)
                 self.scaler = ShardingScaler(self.scaler)
-            elif self.sharding_stage == 'dp' and dist.get_world_size() > 1:
-                self.model = fleet.distributed_model(self.model)
             else:
                 raise NotImplementedError()
+        elif dist.get_world_size() > 1:
+            self.model = fleet.distributed_model(self.model)
@@ -374,7 +378,7 @@ def val(self, **kargs):
                     outs[k] = AverageMeter(k, ':6.3f')
                 outs[k].update(float(v), current_samples)
 
-        log_str = f'Validate Epoch [{self.current_epoch + 1}] '
+        log_str = f'Validate Epoch [{self.current_epoch + 1}]'
         log_items = []
         for name, val in outs.items():
             if isinstance(val, AverageMeter):
diff --git a/passl/hooks/optimizer_hook.py b/passl/hooks/optimizer_hook.py
index 938da1ec..0e9cd3ec 100644
--- a/passl/hooks/optimizer_hook.py
+++ b/passl/hooks/optimizer_hook.py
@@ -20,28 +20,49 @@ class OptimizerHook(Hook):
     def __init__(self, priority=1):
         self.priority = priority
-        
+
     def train_iter_end(self, trainer):
-        if 'Lars' in trainer.cfg['optimizer']['name']:
-            trainer.optimizer.clear_gradients()
-        else:
-            trainer.optimizer.clear_grad()
+        accumulate_steps = trainer.accumulate_grad_steps
+        if accumulate_steps > 1:
+            if trainer.current_iter % accumulate_steps == 0:
+                if 'Lars' in trainer.cfg['optimizer']['name']:
+                    trainer.optimizer.clear_gradients()
+                else:
+                    trainer.optimizer.clear_grad()
 
-        loss = 0
-        loss = trainer.outputs['loss']
-
-        if trainer.use_amp:
-            scaled_loss = trainer.scaler.scale(loss)
-            scaled_loss.backward()
-            trainer.scaler.step(trainer.optimizer)
-            trainer.scaler.update()
+                loss = trainer.outputs['loss'] / accumulate_steps
+                if trainer.use_amp:
+                    scaled_loss = trainer.scaler.scale(loss)
+                    scaled_loss.backward()
+                    trainer.scaler.step(trainer.optimizer)
+                    trainer.scaler.update()
+                else:
+                    loss.backward()
+                    if 'lars' in trainer.optimizer.type:
+                        trainer.optimizer.minimize(loss)
+                    else:
+                        trainer.optimizer.step()
+            else:
+                loss = trainer.outputs['loss'] / accumulate_steps
+                if trainer.use_amp:
+                    scaled_loss = trainer.scaler.scale(loss)
+                    scaled_loss.backward()
+                else:
+                    loss.backward()
         else:
-            loss.backward()
-            if 'lars' in trainer.optimizer.type:
-                trainer.optimizer.minimize(loss)
+            loss = trainer.outputs['loss']
+            if trainer.use_amp:
+                scaled_loss = trainer.scaler.scale(loss)
+                scaled_loss.backward()
+                trainer.scaler.step(trainer.optimizer)
+                trainer.scaler.update()
             else:
-                trainer.optimizer.step()
+                loss.backward()
+                if 'lars' in trainer.optimizer.type:
+                    trainer.optimizer.minimize(loss)
+                else:
+                    trainer.optimizer.step()
 
         if 'loss' not in trainer.outputs:
             trainer.outputs['loss'] = loss
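
Note for reviewers: the sketch below illustrates the gradient-accumulation pattern that the new `accumulate_grad_steps` option enables, assuming PaddlePaddle dygraph objects (a `paddle.nn.Layer` model and a `paddle.optimizer.Optimizer`). It is not PASSL code; `model`, `loss_fn`, `optimizer`, and `loader` are hypothetical placeholders. The loss is divided by the number of accumulation steps so the summed gradient approximates one large-batch update, and the optimizer only steps every `accumulate_grad_steps` iterations.

def train_with_accumulation(model, loss_fn, optimizer, loader,
                            accumulate_grad_steps=2):
    """Illustrative loop only; assumes paddle dygraph model/optimizer."""
    for step, (image, label) in enumerate(loader, start=1):
        # Scale the loss so the accumulated gradient matches a single
        # update over batch_size * accumulate_grad_steps samples.
        loss = loss_fn(model(image), label) / accumulate_grad_steps
        loss.backward()  # gradients accumulate across iterations
        if step % accumulate_grad_steps == 0:
            optimizer.step()        # apply the accumulated gradient
            optimizer.clear_grad()  # reset for the next accumulation window

The sketch uses the common backward → step → clear ordering, whereas the patched `train_iter_end` clears gradients at the start of the boundary iteration before its backward and step. Of the configs above, only the huge model actually accumulates (`accumulate_grad_steps: 2`); the other configs keep the default of 1, which leaves the single-step update path unchanged.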