From 2f7612ad8078c152e2592dc5066f37f0031d9913 Mon Sep 17 00:00:00 2001 From: firestonelib Date: Tue, 26 Oct 2021 19:15:27 +0800 Subject: [PATCH 1/5] add swin transformer --- ...nTransformer_small_patch4_window7_224.yaml | 61 ++ data | 1 + outputs/log.txt | 3 + passl/modeling/architectures/SwinWrapper.py | 76 ++ passl/modeling/architectures/__init__.py | 1 + passl/modeling/backbones/__init__.py | 1 + passl/modeling/backbones/swin_transformer.py | 739 ++++++++++++++++++ passl/modeling/heads/__init__.py | 1 + passl/modeling/heads/swin_transformer_head.py | 80 ++ passl/solver/builder.py | 2 +- 10 files changed, 964 insertions(+), 1 deletion(-) create mode 100644 configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml create mode 120000 data create mode 100644 outputs/log.txt create mode 100644 passl/modeling/architectures/SwinWrapper.py create mode 100644 passl/modeling/backbones/swin_transformer.py create mode 100644 passl/modeling/heads/swin_transformer_head.py diff --git a/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml new file mode 100644 index 00000000..28a77a45 --- /dev/null +++ b/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml @@ -0,0 +1,61 @@ +epochs: 300 +output_dir: output_dir + +model: + name: SwinWrapper + architecture: + name: SwinTransformer + embed_dim: 96 + depths: [2, 2, 18, 2] + num_heads: [3, 6, 12, 24] + window_size: 7 + head: + name: SwinTransformerClsHead + with_avg_pool: True + num_classes: 1000 + in_channels: 768 + +dataloader: + train: + num_workers: 0 + sampler: + batch_size: 128 + shuffle: true + drop_last: True + dataset: + name: ImageNet + dataroot: data/ILSVRC2012/train/ + return_label: True + transforms: + - name: ToRGB + - name: RandomResizedCrop + size: 224 + scale: [0.75, 1.] + ratio: [1., 1.] 
+ interpolation: 'bicubic' + - name: Transpose + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + +lr_scheduler: + name: CosineWarmup + learning_rate: 0.003 + T_max: 93835 + warmup_steps: 10000 + start_lr: 0.00003 + end_lr: 0.003 + +optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.3 + grad_clip: + name: global_norm + value: 1.0 + + +log_config: + name: LogHook + interval: 10 diff --git a/data b/data new file mode 120000 index 00000000..e2fbebc7 --- /dev/null +++ b/data @@ -0,0 +1 @@ +../../PASSL/data \ No newline at end of file diff --git a/outputs/log.txt b/outputs/log.txt new file mode 100644 index 00000000..50c488fd --- /dev/null +++ b/outputs/log.txt @@ -0,0 +1,3 @@ +[10/20 12:22:21] passl INFO: Configs: {'epochs': 300, 'output_dir': 'outputs', 'model': {'name': 'ViTWrapper', 'architecture': {'name': 'VisionTransformer', 'img_size': 384, 'patch_size': 32, 'width': 768, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4, 'qkv_bias': True}, 'head': {'name': 'VisionTransformerClsHead', 'num_classes': 1000, 'in_channels': 768}}, 'dataloader': {'train': {'num_workers': 0, 'sampler': {'batch_size': 128, 'shuffle': True, 'drop_last': True}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/train/', 'return_label': True, 'transforms': [{'name': 'ToRGB'}, {'name': 'RandomResizedCrop', 'size': 384, 'scale': [0.75, 1.0], 'ratio': [1.0, 1.0], 'interpolation': 'bicubic'}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [127.5, 127.5, 127.5], 'std': [127.5, 127.5, 127.5]}]}}}, 'lr_scheduler': {'name': 'CosineWarmup', 'learning_rate': 12.28, 'T_max': 93835, 'warmup_steps': 10000, 'start_lr': 0.01228, 'end_lr': 12.28}, 'optimizer': {'name': 'AdamW', 'beta1': 0.9, 'beta2': 0.999, 'weight_decay': 0.3}, 'log_config': {'name': 'LogHook', 'interval': 10}, 'is_train': True, 'timestamp': '-2021-10-20-12-22'} +[10/20 12:22:50] passl.engine.trainer INFO: Epoch [1/300][10/365] lr: 2.332e-02, eta: 3 days, 10:06:49, time: 2.700, data_time: 2.279, loss 1.3538e+01 (9.1954e+00), acc1 0.000 ( 2.344), acc5 15.625 (12.500) +[10/20 12:23:16] passl.engine.trainer INFO: Epoch [1/300][20/365] lr: 3.559e-02, eta: 3 days, 9:12:34, time: 2.670, data_time: 2.280, loss 1.5001e+01 (1.2273e+01), acc1 1.562 ( 2.508), acc5 7.812 (13.035) diff --git a/passl/modeling/architectures/SwinWrapper.py b/passl/modeling/architectures/SwinWrapper.py new file mode 100644 index 00000000..c17e9004 --- /dev/null +++ b/passl/modeling/architectures/SwinWrapper.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.distributed as dist + +from ..backbones import build_backbone +from ..heads import build_head +from .builder import MODELS + + +@MODELS.register() +class SwinWrapper(nn.Layer): + def __init__(self, + architecture=None, + head=None + ): + """A wrapper for a ViT model as specified in the paper. 
+
+        Args:
+            architecture (dict): A dictionary containing the Swin Transformer
+                backbone instantiation parameters.
+            head (dict): A dictionary containing the classification head
+                instantiation parameters.
+        """
+        super().__init__()
+
+        self.backbone = build_backbone(architecture)
+        self.head = build_head(head)
+
+    def backbone_forward(self, x):
+        x = self.backbone(x)
+        return x
+
+    def train_iter(self, *inputs, **kwargs):
+        img, label = inputs
+        x = self.backbone_forward(img)
+        if isinstance(x, tuple):
+            x = x[-1]
+        _, cls_token = x
+        outs = self.head(cls_token)
+        loss_inputs = (outs, label)
+        outputs = self.head.loss(*loss_inputs)
+        return outputs
+
+    def forward(self, *inputs, mode='train', **kwargs):
+        if mode == 'train':
+            return self.train_iter(*inputs, **kwargs)
+        elif mode == 'test':
+            return self.test_iter(*inputs, **kwargs)
+        elif mode == 'extract':
+            return self.backbone(*inputs)
+        else:
+            raise Exception("No such mode: {}".format(mode))
diff --git a/passl/modeling/architectures/__init__.py b/passl/modeling/architectures/__init__.py
index 7754540c..1e33c5a3 100644
--- a/passl/modeling/architectures/__init__.py
+++ b/passl/modeling/architectures/__init__.py
@@ -20,4 +20,5 @@
 from .simclr import SimCLR
 from .byol_clas import ByolClassification
 from .ViTWrapper import ViTWrapper
+from .SwinWrapper import SwinWrapper
 from .builder import build_model
diff --git a/passl/modeling/backbones/__init__.py b/passl/modeling/backbones/__init__.py
index 543359a5..6163f59f 100644
--- a/passl/modeling/backbones/__init__.py
+++ b/passl/modeling/backbones/__init__.py
@@ -4,3 +4,4 @@
 from .resnetcifar import ResNet
 from .resnetsimclr import ResNetsimclr
 from .vision_transformer import VisionTransformer
+from .swin_transformer import SwinTransformer
diff --git a/passl/modeling/backbones/swin_transformer.py b/passl/modeling/backbones/swin_transformer.py
new file mode 100644
index 00000000..ba733439
--- /dev/null
+++ b/passl/modeling/backbones/swin_transformer.py
@@ -0,0 +1,739 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from .builder import BACKBONES
+from paddle.nn.initializer import TruncatedNormal, Constant
+
+__all__ = ["SwinTransformer"]
+
+trunc_normal_ = TruncatedNormal(std=.02)
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
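+
+# Rough shape walk-through of the forward pass built below, for the
+# patch_size=4, embed_dim=96, window_size=7 configs used in this PR
+# (illustrative numbers that follow directly from those arguments):
+#   PatchEmbed:        (B, 3, 224, 224) -> (B, 56*56, 96)
+#   each BasicLayer:   SwinTransformerBlocks alternating W-MSA (shift_size=0)
+#                      and SW-MSA (shift_size=window_size//2), followed by
+#                      PatchMerging, which halves H/W and doubles C
+#                      (no merging after the last stage)
+#   forward_features:  LayerNorm, then transpose to
+#                      (B, embed_dim * 2 ** (num_layers - 1), H*W),
+#                      e.g. (B, 768, 49) for the small config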
+ + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.reshape( + [B, H // window_size, window_size, W // window_size, window_size, C]) + windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape( + [-1, window_size, window_size, C]) + return windows + + +def window_reverse(windows, window_size, H, W, C): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1]) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1]) + return x + + +class WindowAttention(nn.Layer): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = self.create_parameter( + shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads), # 2*Wh-1 * 2*Ww-1, nH + default_initializer=zeros_) + self.add_parameter("relative_position_bias_table", + self.relative_position_bias_table) + + # get pair-wise relative position index for each token inside the window + coords_h = paddle.arange(self.window_size[0]) + coords_w = paddle.arange(self.window_size[1]) + coords = paddle.stack(paddle.meshgrid( + [coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww + + coords_flatten_1 = coords_flatten.unsqueeze(axis=2) + coords_flatten_2 = coords_flatten.unsqueeze(axis=1) + relative_coords = coords_flatten_1 - coords_flatten_2 + + relative_coords = relative_coords.transpose( + [1, 2, 0]) # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[ + 0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table) + self.softmax = nn.Softmax(axis=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape((B_, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) + + index = self.relative_position_index.reshape([-1]) + + relative_position_bias = paddle.index_select( + self.relative_position_bias_table, index) + relative_position_bias = relative_position_bias.reshape([ + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ]) # Wh*Ww,Wh*Ww,nH + + relative_position_bias = relative_position_bias.transpose([2, 0, 1]) # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N + ]) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.reshape([-1, self.num_heads, N, N]) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape([B_, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self): + return "dim={}, window_size={}, num_heads={}".format( + self.dim, self.window_size, self.num_heads) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Layer): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + epsilon=1e-5): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + self.norm1 = eval(norm_layer)(dim) + self.attn = WindowAttention( + dim, + window_size=(self.window_size, self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else Identity() + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = paddle.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.reshape( + [-1, self.window_size * self.window_size]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + + huns = -100.0 * paddle.ones_like(attn_mask) + attn_mask = huns * (attn_mask != 0).astype("float32") + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.reshape([B, H, W, C]) + + # cyclic shift + if self.shift_size > 0: + shifted_x = paddle.roll( + x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.reshape( + [-1, self.window_size * self.window_size, + C]) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.reshape( + [-1, self.window_size, self.window_size, C]) + shifted_x = window_reverse(attn_windows, self.window_size, H, W, + C) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = paddle.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + axis=(1, 2)) + else: + x = shifted_x + x = x.reshape([B, H * W, C]) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self): + return "dim={}, input_resolution={}, num_heads={}, window_size={}, shift_size={}, mlp_ratio={}".format( + self.dim, self.input_resolution, self.num_heads, self.window_size, + self.shift_size, self.mlp_ratio) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Layer): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False) + self.norm = eval(norm_layer)(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, "x size ({}*{}) are not even.".format( + H, W) + + x = x.reshape([B, H, W, C]) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.reshape([B, (H // 2) * (W // 2), 4 * C]) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return "input_resolution={}, dim={}".format(self.input_resolution, + self.dim) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Layer): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer='nn.LayerNorm', + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.LayerList([ + SwinTransformerBlock(dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + #if self.use_checkpoint: + # x = checkpoint.checkpoint(blk, x) + #else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self): + return "dim={}, input_resolution={}, depth={}".format( + self.dim, self.input_resolution, self.depth) + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = paddle.to_tensor(1 - drop_prob) + shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype) + random_tensor = paddle.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class PatchEmbed(nn.Layer): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias_attr=False) + if norm_layer is not None: + self.norm = eval(norm_layer)(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + "Input image size ({}*{}) doesn't match model ({}*{}).".format( + H, W, self.img_size[0], self.img_size[1]) + + x = self.proj(x).flatten(2).transpose([0, 2, 1]) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * ( + self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +@BACKBONES.register() +class SwinTransformer(nn.Layer): + """ Swin Transformer + A PASSL implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer='nn.LayerNorm', + ape=False, + patch_norm=True, + use_checkpoint=False, + epsilon=1e-5, + **kwargs): + super(SwinTransformer, self).__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = self.create_parameter( + shape=(1, num_patches, embed_dim), default_initializer=zeros_) + self.add_parameter("absolute_pos_embed", self.absolute_pos_embed) + trunc_normal_(self.absolute_pos_embed) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = np.linspace(0, drop_path_rate, + sum(depths)).tolist() # stochastic depth decay rule + + # build layers + self.layers = nn.LayerList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim = int(embed_dim * 2**i_layer), + input_resolution=(patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = eval(norm_layer)(self.num_features, epsilon=epsilon) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x += self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x).transpose([0, 2, 1]) + + return x + + def forward(self, x): + x = self.forward_features(x) + return x diff --git a/passl/modeling/heads/__init__.py b/passl/modeling/heads/__init__.py index 02f04ee4..ba6ee7ae 100644 --- a/passl/modeling/heads/__init__.py +++ b/passl/modeling/heads/__init__.py @@ -20,3 +20,4 @@ from .builder import build_head from .simclr_contrastive_head import SimCLRContrastiveHead from .vision_transformer_head import VisionTransformerClsHead +from .swin_transformer_head import SwinTransformerClsHead diff --git a/passl/modeling/heads/swin_transformer_head.py b/passl/modeling/heads/swin_transformer_head.py new file mode 100644 index 00000000..dc0e8acb --- /dev/null +++ b/passl/modeling/heads/swin_transformer_head.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from ...modules.init import reset_parameters, normal_init +from .builder import HEADS +from .clas_head import ClasHead + + +@HEADS.register() +class SwinTransformerClsHead(ClasHead): + """Vision Transformer classifier head. + + Args: + with_avg_pool (bool): Use average pooling or not. Default: False. + in_channels (int): Number of channels in the input feature map. + num_classes (int): Number of categories excluding the background + category. + """ + def __init__(self, + with_avg_pool=False, + in_channels=2048, + num_classes=1000): + super(SwinTransformerClsHead, self).__init__() + self.with_avg_pool = with_avg_pool + self.in_channels = in_channels + self.num_classes = num_classes + self.criterion = nn.CrossEntropyLoss() + self.fc_cls = nn.Linear(in_channels, num_classes) + + normal_init(self.fc_cls, mean=0.0, std=0.01, bias=0.0) + + def forward(self, x): + if self.with_avg_pool: + assert x.dim() == 4, \ + "Tensor must has 4 dims, got: {}".format(x.dim()) + x = self.avg_pool(x) + x = paddle.flatten(x, 1) + cls_score = self.fc_cls(x) + return cls_score + + def loss(self, cls_score, labels): + losses = dict() + + losses['loss'] = self.criterion(cls_score, labels) + losses['acc1'], losses['acc5'] = accuracy(cls_score, + labels, + topk=(1, 5)) + return losses + + +def accuracy(output, target, topk=(1, )): + """Computes the accuracy over the k top predictions for the specified values of k""" + with paddle.no_grad(): + maxk = max(topk) + batch_size = target.shape[0] + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = paddle.cast(pred == target.reshape([1, -1]).expand_as(pred), + 'float32') + + res = [] + for k in topk: + correct_k = correct[:k].reshape([-1]).sum(0, keepdim=True) + res.append(correct_k * 100.0 / batch_size) + return res diff --git a/passl/solver/builder.py b/passl/solver/builder.py index 3021d6a0..06d69a85 100644 --- a/passl/solver/builder.py +++ b/passl/solver/builder.py @@ -71,7 +71,7 @@ def build_optimizer(cfg, lr_scheduler, parameters=None): grad_clip_cfg = cfg_.pop('grad_clip') if grad_clip_cfg['name'] == 'global_norm': clip_norm = grad_clip_cfg['value'] - grad_clip = paddle.nn.clip.GradientClipByGlobalNorm(clip_norm=clip_norm) + grad_clip = paddle.nn.clip.ClipGradByGlobalNorm(clip_norm=clip_norm) else: grad_clip = None if name == 'LarsMomentumOptimizer': From b5c2286bd950008a81ca8db3739f526920983e23 Mon Sep 17 00:00:00 2001 From: firestonelib Date: Tue, 26 Oct 2021 19:16:47 +0800 Subject: [PATCH 2/5] add swin transformer --- data | 1 - 1 file changed, 1 deletion(-) delete mode 120000 data diff --git a/data b/data deleted file mode 120000 index e2fbebc7..00000000 --- a/data +++ /dev/null @@ -1 +0,0 @@ -../../PASSL/data \ No newline at end of file From ba20e0496e25b2c770051329281e91cd57fdfadb Mon Sep 17 00:00:00 2001 From: firestonelib Date: Tue, 26 Oct 2021 19:17:21 +0800 Subject: [PATCH 3/5] add swin transformer --- outputs/log.txt | 3 --- 1 file changed, 3 
deletions(-) delete mode 100644 outputs/log.txt diff --git a/outputs/log.txt b/outputs/log.txt deleted file mode 100644 index 50c488fd..00000000 --- a/outputs/log.txt +++ /dev/null @@ -1,3 +0,0 @@ -[10/20 12:22:21] passl INFO: Configs: {'epochs': 300, 'output_dir': 'outputs', 'model': {'name': 'ViTWrapper', 'architecture': {'name': 'VisionTransformer', 'img_size': 384, 'patch_size': 32, 'width': 768, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4, 'qkv_bias': True}, 'head': {'name': 'VisionTransformerClsHead', 'num_classes': 1000, 'in_channels': 768}}, 'dataloader': {'train': {'num_workers': 0, 'sampler': {'batch_size': 128, 'shuffle': True, 'drop_last': True}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/train/', 'return_label': True, 'transforms': [{'name': 'ToRGB'}, {'name': 'RandomResizedCrop', 'size': 384, 'scale': [0.75, 1.0], 'ratio': [1.0, 1.0], 'interpolation': 'bicubic'}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [127.5, 127.5, 127.5], 'std': [127.5, 127.5, 127.5]}]}}}, 'lr_scheduler': {'name': 'CosineWarmup', 'learning_rate': 12.28, 'T_max': 93835, 'warmup_steps': 10000, 'start_lr': 0.01228, 'end_lr': 12.28}, 'optimizer': {'name': 'AdamW', 'beta1': 0.9, 'beta2': 0.999, 'weight_decay': 0.3}, 'log_config': {'name': 'LogHook', 'interval': 10}, 'is_train': True, 'timestamp': '-2021-10-20-12-22'} -[10/20 12:22:50] passl.engine.trainer INFO: Epoch [1/300][10/365] lr: 2.332e-02, eta: 3 days, 10:06:49, time: 2.700, data_time: 2.279, loss 1.3538e+01 (9.1954e+00), acc1 0.000 ( 2.344), acc5 15.625 (12.500) -[10/20 12:23:16] passl.engine.trainer INFO: Epoch [1/300][20/365] lr: 3.559e-02, eta: 3 days, 9:12:34, time: 2.670, data_time: 2.280, loss 1.5001e+01 (1.2273e+01), acc1 1.562 ( 2.508), acc5 7.812 (13.035) From bbaec05b5fe1dc23423b8071f6bc612fa23d123b Mon Sep 17 00:00:00 2001 From: firestonelib Date: Tue, 28 Dec 2021 13:48:43 +0800 Subject: [PATCH 4/5] add accumulate gradients --- ...inTransformer_base_patch4_window7_224.yaml | 2 + ...nTransformer_giant_patch4_window7_224.yaml | 8 +- ...inTransformer_huge_patch4_window7_224.yaml | 9 +- ...nTransformer_small_patch4_window7_224.yaml | 2 +- ...inTransformer_tiny_patch4_window7_224.yaml | 2 + passl/engine/trainer.py | 18 +-- passl/hooks/optimizer_hook.py | 111 ++++++++++-------- 7 files changed, 84 insertions(+), 68 deletions(-) diff --git a/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml index 63d79201..1cd747cb 100644 --- a/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml +++ b/configs/swin_transformer/SwinTransformer_base_patch4_window7_224.yaml @@ -2,6 +2,8 @@ epochs: 300 output_dir: output_dir seed: 0 +accumulate_grad_steps: 1 + model: name: SwinWrapper architecture: diff --git a/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml index b224dc0c..6aa02f98 100644 --- a/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml +++ b/configs/swin_transformer/SwinTransformer_giant_patch4_window7_224.yaml @@ -14,15 +14,11 @@ AMP: "c_softmax_with_cross_entropy", "elementwise_div"] level: 'O1' -hybrid: - dp_degree: 8 - mp_degree: 1 - pp_degree: 1 - sharding: sharding_stage: 2 # 2 or 'dp' offload: False - accumulate_grad: False + +accumulate_grad_steps: 1 model: name: SwinWrapper diff --git a/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml 
b/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml index d424db59..ff7932bf 100644 --- a/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml +++ b/configs/swin_transformer/SwinTransformer_huge_patch4_window7_224.yaml @@ -14,16 +14,11 @@ AMP: custom_black_list: ["reduce_mean", "reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"] level: 'O1' - -hybrid: - dp_degree: 8 - mp_degree: 1 - pp_degree: 1 - sharding: sharding_stage: 2 # 2 or 'dp' offload: False - accumulate_grad: False + +accumulate_grad_steps: 2 model: name: SwinWrapper diff --git a/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml index da9cd628..0a91d41e 100644 --- a/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml +++ b/configs/swin_transformer/SwinTransformer_small_patch4_window7_224.yaml @@ -1,7 +1,7 @@ epochs: 300 output_dir: output_dir seed: 0 - +accumulate_grad_steps: 1 model: name: SwinWrapper architecture: diff --git a/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml b/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml index 47f902db..27e74f30 100644 --- a/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml +++ b/configs/swin_transformer/SwinTransformer_tiny_patch4_window7_224.yaml @@ -2,6 +2,8 @@ epochs: 300 output_dir: output_dir seed: 0 +accumulate_grad_steps: 1 + model: name: SwinWrapper architecture: diff --git a/passl/engine/trainer.py b/passl/engine/trainer.py index ac1e0022..78520f59 100644 --- a/passl/engine/trainer.py +++ b/passl/engine/trainer.py @@ -113,6 +113,8 @@ def __init__(self, cfg): use_simclr_iters = cfg.get('use_simclr_iters', False) self.use_simclr_iters = use_simclr_iters self.epochs = cfg.get('epochs', None) + self.accumulate_grad_steps = cfg.get('accumulate_grad_steps', 1) + self.accumulate_grads = True if self.accumulate_grad_steps > 1 else False self.timestamp = cfg.timestamp self.logs = OrderedDict() # Ensure that the vdl log file can be closed normally @@ -147,7 +149,7 @@ def __init__(self, cfg): # distributed settings if dist.get_world_size() > 1: strategy = fleet.DistributedStrategy() - ## Hybrid Parallel Training + # Hybrid Parallel Training strategy.hybrid_configs = cfg.pop('hybrid') if 'hybrid' in cfg else {} fleet.init(is_collective=True, strategy=strategy) hcg = fleet.get_hybrid_communicate_group() @@ -157,7 +159,7 @@ def __init__(self, cfg): set_hyrbid_parallel_seed(seed, 0, mp_rank, pp_rank) # amp training - self.use_amp = cfg.get('use_amp', False) #if 'use_amp' in cfg else False + self.use_amp = cfg.get('use_amp', False) if self.use_amp: amp_cfg = cfg.pop('AMP') self.auto_cast = amp_cfg.pop('auto_cast') @@ -170,22 +172,24 @@ def __init__(self, cfg): self.sharding_strategies = cfg.get('sharding', False) if self.sharding_strategies: self.sharding_stage = self.sharding_strategies['sharding_stage'] - accumulate_grad = self.sharding_strategies['accumulate_grad'] offload = self.sharding_strategies['offload'] + # Note: Only support partition optimizer stages and gradient now! 
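+            # Stage 2 sharding (ZeRO-2 style): optimizer states and gradients
+            # are partitioned across the data-parallel ranks, while the model
+            # parameters themselves stay replicated on every rank.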
if self.sharding_stage == 2: + # Partition Optimizer self.optimizer = ShardingOptimizerStage2( params=self.model.parameters(), optim=self.optimizer, offload=offload) + # Partition Gradients self.model = ShardingStage2( self.model, self.optimizer, - accumulate_grads=accumulate_grad) + accumulate_grads=self.accumulate_grads) self.scaler = ShardingScaler(self.scaler) - elif self.sharding_stage == 'dp' and dist.get_world_size() > 1: - self.model = fleet.distributed_model(self.model) else: raise NotImplementedError() + elif dist.get_world_size() > 1: + self.model = fleet.distributed_model(self.model) @@ -374,7 +378,7 @@ def val(self, **kargs): outs[k] = AverageMeter(k, ':6.3f') outs[k].update(float(v), current_samples) - log_str = f'Validate Epoch [{self.current_epoch + 1}] ' + log_str = f'Validate Epoch [{self.current_epoch + 1}]' log_items = [] for name, val in outs.items(): if isinstance(val, AverageMeter): diff --git a/passl/hooks/optimizer_hook.py b/passl/hooks/optimizer_hook.py index 938da1ec..1728638b 100644 --- a/passl/hooks/optimizer_hook.py +++ b/passl/hooks/optimizer_hook.py @@ -1,47 +1,64 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .hook import Hook -from .builder import HOOKS - - -@HOOKS.register() -class OptimizerHook(Hook): - def __init__(self, priority=1): - self.priority = priority - - def train_iter_end(self, trainer): - if 'Lars' in trainer.cfg['optimizer']['name']: - trainer.optimizer.clear_gradients() - else: - trainer.optimizer.clear_grad() - - loss = 0 - loss = trainer.outputs['loss'] - - if trainer.use_amp: - scaled_loss = trainer.scaler.scale(loss) - scaled_loss.backward() - trainer.scaler.step(trainer.optimizer) - trainer.scaler.update() - - else: - loss.backward() - if 'lars' in trainer.optimizer.type: - trainer.optimizer.minimize(loss) - else: - trainer.optimizer.step() - - if 'loss' not in trainer.outputs: - trainer.outputs['loss'] = loss +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .hook import Hook +from .builder import HOOKS + + +@HOOKS.register() +class OptimizerHook(Hook): + def __init__(self, priority=1): + self.priority = priority + + + def train_iter_end(self, trainer): + accumulate_steps = trainer.accumulate_grad_steps + if accumulate_steps > 1: + if trainer.current_iter % accumulate_steps == 0: + if 'Lars' in trainer.cfg['optimizer']['name']: + trainer.optimizer.clear_gradients() + else: + trainer.optimizer.clear_grad() + + loss = 0 + loss = trainer.outputs['loss'] / accumulate_steps + if trainer.use_amp: + scaled_loss = trainer.scaler.scale(loss) + scaled_loss.backward() + trainer.scaler.step(trainer.optimizer) + trainer.scaler.update() + + else: + loss.backward() + if 'lars' in trainer.optimizer.type: + trainer.optimizer.minimize(loss) + else: + trainer.optimizer.step() + else: + loss = 0 + loss = trainer.outputs['loss'] + if trainer.use_amp: + scaled_loss = trainer.scaler.scale(loss) + scaled_loss.backward() + trainer.scaler.step(trainer.optimizer) + trainer.scaler.update() + else: + loss.backward() + if 'lars' in trainer.optimizer.type: + trainer.optimizer.minimize(loss) + else: + trainer.optimizer.step() + + if 'loss' not in trainer.outputs: + trainer.outputs['loss'] = loss From 670e94696f17ab7908effa414eeebf815a6535d2 Mon Sep 17 00:00:00 2001 From: firestonelib Date: Tue, 28 Dec 2021 15:10:10 +0800 Subject: [PATCH 5/5] modify optimizer --- passl/hooks/optimizer_hook.py | 130 ++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 63 deletions(-) diff --git a/passl/hooks/optimizer_hook.py b/passl/hooks/optimizer_hook.py index 1728638b..0e9cd3ec 100644 --- a/passl/hooks/optimizer_hook.py +++ b/passl/hooks/optimizer_hook.py @@ -1,64 +1,68 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .hook import Hook -from .builder import HOOKS - - -@HOOKS.register() -class OptimizerHook(Hook): - def __init__(self, priority=1): - self.priority = priority - - - def train_iter_end(self, trainer): - accumulate_steps = trainer.accumulate_grad_steps - if accumulate_steps > 1: - if trainer.current_iter % accumulate_steps == 0: - if 'Lars' in trainer.cfg['optimizer']['name']: - trainer.optimizer.clear_gradients() - else: - trainer.optimizer.clear_grad() - - loss = 0 - loss = trainer.outputs['loss'] / accumulate_steps - if trainer.use_amp: - scaled_loss = trainer.scaler.scale(loss) - scaled_loss.backward() - trainer.scaler.step(trainer.optimizer) - trainer.scaler.update() - - else: - loss.backward() - if 'lars' in trainer.optimizer.type: - trainer.optimizer.minimize(loss) - else: - trainer.optimizer.step() - else: - loss = 0 - loss = trainer.outputs['loss'] - if trainer.use_amp: - scaled_loss = trainer.scaler.scale(loss) - scaled_loss.backward() - trainer.scaler.step(trainer.optimizer) +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hook import Hook +from .builder import HOOKS + + +@HOOKS.register() +class OptimizerHook(Hook): + def __init__(self, priority=1): + self.priority = priority + + def train_iter_end(self, trainer): + accumulate_steps = trainer.accumulate_grad_steps + if accumulate_steps > 1: + if trainer.current_iter % accumulate_steps == 0: + if 'Lars' in trainer.cfg['optimizer']['name']: + trainer.optimizer.clear_gradients() + else: + trainer.optimizer.clear_grad() + + loss = trainer.outputs['loss'] / accumulate_steps + if trainer.use_amp: + scaled_loss = trainer.scaler.scale(loss) + scaled_loss.backward() + trainer.scaler.step(trainer.optimizer) + trainer.scaler.update() + + else: + loss.backward() + if 'lars' in trainer.optimizer.type: + trainer.optimizer.minimize(loss) + else: + trainer.optimizer.step() + else: + loss = trainer.outputs['loss'] / accumulate_steps + if trainer.use_amp: + scaled_loss = trainer.scaler.scale(loss) + scaled_loss.backward() + else: + loss.backward() + else: + loss = trainer.outputs['loss'] + if trainer.use_amp: + scaled_loss = trainer.scaler.scale(loss) + scaled_loss.backward() + trainer.scaler.step(trainer.optimizer) trainer.scaler.update() - else: - loss.backward() - if 'lars' in trainer.optimizer.type: - trainer.optimizer.minimize(loss) - else: - trainer.optimizer.step() - - if 'loss' not in trainer.outputs: - trainer.outputs['loss'] = loss + else: + loss.backward() + if 'lars' in trainer.optimizer.type: + trainer.optimizer.minimize(loss) + else: + trainer.optimizer.step() + + if 'loss' not in trainer.outputs: + trainer.outputs['loss'] = loss