diff --git a/README.md b/README.md index 816bf731..7a8938eb 100644 --- a/README.md +++ b/README.md @@ -23,19 +23,13 @@ This repo is the official implementation of ["Swin Transformer: Hierarchical Vis > **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions. -> **Feature-Distillation**: Will appear in [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation). - -## Activity notification - -* 09/18/2022: Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks: - - [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks. - - [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks. +> **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation). +## Updates -$\qquad$ [ [Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [ [IC Challenge] ](https://eval.ai/web/challenges/challenge-page/1832/overview) -$\qquad$ [ [OD Challenge] ](https://eval.ai/web/challenges/challenge-page/1839/overview) +***11/30/2022*** -## Updates +1. Models and codes of **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details, and the checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L). ***09/24/2022*** diff --git a/config.py b/config.py index 1671ec34..767459f8 100644 --- a/config.py +++ b/config.py @@ -245,7 +245,7 @@ # Tag of experiment, overwritten by command line argument _C.TAG = 'default' # Frequency to save checkpoint -_C.SAVE_FREQ = 1 +_C.SAVE_FREQ = 10 # Frequency to logging info _C.PRINT_FREQ = 10 # Fixed random seed diff --git a/configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml b/configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml new file mode 100755 index 00000000..944fb07d --- /dev/null +++ b/configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml @@ -0,0 +1,11 @@ +DATA: + IMG_SIZE: 128 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_patch4_window8_128_hvd + DROP_PATH_RATE: 0.5 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 8 diff --git a/data/__init__.py b/data/__init__.py index 5baad7ed..c41d6559 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,4 +1,5 @@ from .build import build_loader as _build_loader +from .build import build_loader_hvd as _build_loader_hvd from .data_simmim_pt import build_loader_simmim from .data_simmim_ft import build_loader_finetune @@ -10,3 +11,7 @@ def build_loader(config, simmim=False, is_pretrain=False): return build_loader_simmim(config) else: return build_loader_finetune(config) + +def build_loader_hvd(config): + return _build_loader_hvd(config) + diff --git a/data/build.py b/data/build.py index 5799f253..52677e78 100644 --- a/data/build.py +++ b/data/build.py @@ -9,6 +9,7 @@ import torch import numpy as np import torch.distributed as dist +import horovod.torch as hvd from torchvision import datasets, transforms from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import Mixup @@ -95,6 +96,60 @@ def build_loader(config): return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn +def build_loader_hvd(config): + config.defrost() + dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) + config.freeze() + print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build train dataset") + dataset_val, _ = build_dataset(is_train=False, config=config) + print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build val dataset") + + num_tasks = hvd.size() + global_rank = hvd.rank() + if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': + indices = np.arange(hvd.rank(), len(dataset_train), hvd.size()) + sampler_train = SubsetRandomSampler(indices) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + + if config.TEST.SEQUENTIAL: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_val = torch.utils.data.distributed.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=config.TEST.SHUFFLE + ) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=config.DATA.BATCH_SIZE, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=config.DATA.PIN_MEMORY, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=config.DATA.BATCH_SIZE, + shuffle=False, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=config.DATA.PIN_MEMORY, + drop_last=False + ) + + # setup mixup / cutmix + mixup_fn = None + mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None + if mixup_active: + mixup_fn = Mixup( + mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, + prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, + label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) + + return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn + + def build_dataset(is_train, config): transform = build_transform(is_train, config) if config.DATA.DATASET == 'imagenet': diff --git a/main_hvd.py b/main_hvd.py new file mode 100755 index 00000000..72312826 --- /dev/null +++ b/main_hvd.py @@ -0,0 +1,353 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +import os +import time +import json +import random +import argparse +import datetime +import numpy as np + +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import horovod.torch as hvd + +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.utils import accuracy, AverageMeter + +from config import get_config +from models import build_model +from data import build_loader, build_loader_hvd +from lr_scheduler import build_scheduler +from optimizer import build_optimizer +from logger import create_logger +from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, NativeScalerWithGradNormCountHvd, auto_resume_helper, \ + reduce_tensor, reduce_tensor_hvd + + +def parse_option(): + parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) + parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) + parser.add_argument( + "--opts", + help="Modify config options by adding 'KEY VALUE' pairs. ", + default=None, + nargs='+', + ) + + # easy config modification + parser.add_argument('--batch-size', type=int, help="batch size for single GPU") + parser.add_argument('--data-path', type=str, help='path to dataset') + parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') + parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], + help='no: no cache, ' + 'full: cache all data, ' + 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') + parser.add_argument('--pretrained', + help='pretrained weight from checkpoint, could be imagenet22k pretrained weight') + parser.add_argument('--resume', help='resume from checkpoint') + parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") + parser.add_argument('--use-checkpoint', action='store_true', + help="whether to use gradient checkpointing to save memory") + parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp') + parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'], + help='mixed precision opt level, if O0, no amp is used (deprecated!)') + parser.add_argument('--output', default='output', type=str, metavar='PATH', + help='root of output folder, the full path is // (default: output)') + parser.add_argument('--tag', help='tag of experiment') + parser.add_argument('--eval', action='store_true', help='Perform evaluation only') + parser.add_argument('--throughput', action='store_true', help='Test throughput only') + + # distributed training + parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') + + # for acceleration + parser.add_argument('--fused_window_process', action='store_true', + help='Fused window shift & window partition, similar for reversed part.') + parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.') + ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb + parser.add_argument('--optim', type=str, + help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.') + + args, unparsed = parser.parse_known_args() + + config = get_config(args) + + return args, config + + +def main(config): + dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader_hvd(config) + + logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") + model = build_model(config) + logger.info(str(model)) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"number of params: {n_parameters}") + if hasattr(model, 'flops'): + flops = model.flops() + logger.info(f"number of GFLOPs: {flops / 1e9}") + + model.cuda() + model_without_ddp = model + + optimizer = build_optimizer(config, model) + optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) + loss_scaler = NativeScalerWithGradNormCountHvd() + + if config.TRAIN.ACCUMULATION_STEPS > 1: + lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) + else: + lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) + + if config.AUG.MIXUP > 0.: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif config.MODEL.LABEL_SMOOTHING > 0.: + criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) + else: + criterion = torch.nn.CrossEntropyLoss() + + max_accuracy = 0.0 + + if config.TRAIN.AUTO_RESUME: + resume_file = auto_resume_helper(config.OUTPUT) + if resume_file: + if config.MODEL.RESUME: + logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") + config.defrost() + config.MODEL.RESUME = resume_file + config.freeze() + logger.info(f'auto resuming from {resume_file}') + else: + logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') + + if config.MODEL.RESUME: + max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) + acc1, acc5, loss = validate(config, data_loader_val, model) + logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + if config.EVAL_MODE: + return + + if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): + load_pretrained(config, model_without_ddp, logger) + acc1, acc5, loss = validate(config, data_loader_val, model) + logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + + if config.THROUGHPUT_MODE: + throughput(data_loader_val, model, logger) + return + + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + logger.info("Start training") + start_time = time.time() + for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): + data_loader_train.sampler.set_epoch(epoch) + + train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, + loss_scaler) + if hvd.rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): + save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, + logger) + + acc1, acc5, loss = validate(config, data_loader_val, model) + logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + max_accuracy = max(max_accuracy, acc1) + logger.info(f'Max accuracy: {max_accuracy:.2f}%') + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info('Training time {}'.format(total_time_str)) + + +def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): + model.train() + optimizer.zero_grad() + + num_steps = len(data_loader) + batch_time = AverageMeter() + loss_meter = AverageMeter() + norm_meter = AverageMeter() + scaler_meter = AverageMeter() + + start = time.time() + end = time.time() + for idx, (samples, targets) in enumerate(data_loader): + samples = samples.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): + outputs = model(samples) + loss = criterion(outputs, targets) + loss = loss / config.TRAIN.ACCUMULATION_STEPS + + # this attribute is added by timm on one optimizer (adahessian) + is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, + parameters=model.parameters(), create_graph=is_second_order, + update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) + # grad_norm = None + if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: + optimizer.zero_grad() + lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) + loss_scale_value = 1.0 # loss_scaler.state_dict()["scale"] + torch.cuda.synchronize() + + loss_meter.update(loss.item(), targets.size(0)) + if grad_norm is not None: # loss_scaler return None if not update + norm_meter.update(grad_norm) + scaler_meter.update(loss_scale_value) + batch_time.update(time.time() - end) + end = time.time() + + if idx % config.PRINT_FREQ == 0: + lr = optimizer.param_groups[0]['lr'] + wd = optimizer.param_groups[0]['weight_decay'] + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + etas = batch_time.avg * (num_steps - idx) + logger.info( + f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' + f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t' + f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' + f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' + f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' + f'mem {memory_used:.0f}MB') + epoch_time = time.time() - start + logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") + + +@torch.no_grad() +def validate(config, data_loader, model): + criterion = torch.nn.CrossEntropyLoss() + model.eval() + + batch_time = AverageMeter() + loss_meter = AverageMeter() + acc1_meter = AverageMeter() + acc5_meter = AverageMeter() + + end = time.time() + for idx, (images, target) in enumerate(data_loader): + images = images.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): + output = model(images) + + # measure accuracy and record loss + loss = criterion(output, target) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + acc1 = reduce_tensor_hvd(acc1) + acc5 = reduce_tensor_hvd(acc5) + loss = reduce_tensor_hvd(loss) + + loss_meter.update(loss.item(), target.size(0)) + acc1_meter.update(acc1.item(), target.size(0)) + acc5_meter.update(acc5.item(), target.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if idx % config.PRINT_FREQ == 0: + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + logger.info( + f'Test: [{idx}/{len(data_loader)}]\t' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' + f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' + f'Mem {memory_used:.0f}MB') + logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') + return acc1_meter.avg, acc5_meter.avg, loss_meter.avg + + +@torch.no_grad() +def throughput(data_loader, model, logger): + model.eval() + + for idx, (images, _) in enumerate(data_loader): + images = images.cuda(non_blocking=True) + batch_size = images.shape[0] + for i in range(50): + model(images) + torch.cuda.synchronize() + logger.info(f"throughput averaged with 30 times") + tic1 = time.time() + for i in range(30): + model(images) + torch.cuda.synchronize() + tic2 = time.time() + logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") + return + + +if __name__ == '__main__': + args, config = parse_option() + + if config.AMP_OPT_LEVEL: + print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") + + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ['WORLD_SIZE']) + print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") + else: + rank = -1 + world_size = -1 + + hvd.init() + torch.cuda.set_device(hvd.local_rank()) + # torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + # torch.distributed.barrier() + + seed = config.SEED + hvd.rank() + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + # linear scale the learning rate according to total batch size, may not be optimal + linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * hvd.size() / 512.0 + linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * hvd.size() / 512.0 + linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * hvd.size() / 512.0 + # gradient accumulation also need to scale the learning rate + if config.TRAIN.ACCUMULATION_STEPS > 1: + linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS + config.defrost() + config.TRAIN.BASE_LR = linear_scaled_lr + config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr + config.TRAIN.MIN_LR = linear_scaled_min_lr + config.freeze() + + os.makedirs(config.OUTPUT, exist_ok=True) + logger = create_logger(output_dir=config.OUTPUT, dist_rank=hvd.rank(), name=f"{config.MODEL.NAME}") + + if hvd.rank() == 0: + path = os.path.join(config.OUTPUT, "config.json") + with open(path, "w") as f: + f.write(config.dump()) + logger.info(f"Full config saved to {path}") + + # print config + logger.info(config.dump()) + logger.info(json.dumps(vars(args))) + + main(config) diff --git a/train.sh b/train.sh new file mode 100644 index 00000000..88805ed4 --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +horovodrun -np 4 python main_hvd.py --cfg configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml --data-path ../../datasets/tiny-imagenet-200 --batch-size 64 --disable_amp --local_rank 0 diff --git a/utils.py b/utils.py index eb607cfe..4d80bbf2 100644 --- a/utils.py +++ b/utils.py @@ -8,6 +8,7 @@ import os import torch import torch.distributed as dist +import horovod.torch as hvd from torch._six import inf @@ -176,6 +177,13 @@ def reduce_tensor(tensor): return rt +def reduce_tensor_hvd(tensor): + rt = tensor.clone() + avg_rt = hvd.allreduce(rt) + # rt /= hvd.size() + return avg_rt + + def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: if isinstance(parameters, torch.Tensor): parameters = [parameters] @@ -192,11 +200,42 @@ def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: return total_norm +class NativeScalerWithGradNormCountHvd: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler(enabled=False) + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + optimizer.synchronize() + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = ampscaler_get_grad_norm(parameters) + with optimizer.skip_synchronize(): + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + class NativeScalerWithGradNormCount: state_dict_key = "amp_scaler" def __init__(self): - self._scaler = torch.cuda.amp.GradScaler() + self._scaler = torch.cuda.amp.GradScaler(enabled=False) def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): self._scaler.scale(loss).backward(create_graph=create_graph) @@ -218,4 +257,4 @@ def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): - self._scaler.load_state_dict(state_dict) + self._scaler.load_state_dict(state_dict) \ No newline at end of file