diff --git a/README.md b/README.md
index 816bf731..7a8938eb 100644
--- a/README.md
+++ b/README.md
@@ -23,19 +23,13 @@ This repo is the official implementation of ["Swin Transformer: Hierarchical Vis
> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions.
-> **Feature-Distillation**: Will appear in [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).
-
-## Activity notification
-
-* 09/18/2022: Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks:
- - [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks.
- - [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks.
+> **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).
+## Updates
-$\qquad$ [
[Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [
[IC Challenge] ](https://eval.ai/web/challenges/challenge-page/1832/overview)
-$\qquad$ [
[OD Challenge] ](https://eval.ai/web/challenges/challenge-page/1839/overview)
+***11/30/2022***
-## Updates
+1. Models and code for **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details and for the released checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L).
***09/24/2022***
diff --git a/config.py b/config.py
index 1671ec34..767459f8 100644
--- a/config.py
+++ b/config.py
@@ -245,7 +245,7 @@
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
-_C.SAVE_FREQ = 1
+_C.SAVE_FREQ = 10
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
diff --git a/configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml b/configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml
new file mode 100755
index 00000000..944fb07d
--- /dev/null
+++ b/configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml
@@ -0,0 +1,11 @@
+DATA:
+ IMG_SIZE: 128
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_patch4_window8_128_hvd
+ DROP_PATH_RATE: 0.5
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 8
diff --git a/data/__init__.py b/data/__init__.py
index 5baad7ed..c41d6559 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -1,4 +1,5 @@
from .build import build_loader as _build_loader
+from .build import build_loader_hvd as _build_loader_hvd
from .data_simmim_pt import build_loader_simmim
from .data_simmim_ft import build_loader_finetune
@@ -10,3 +11,7 @@ def build_loader(config, simmim=False, is_pretrain=False):
return build_loader_simmim(config)
else:
return build_loader_finetune(config)
+
+def build_loader_hvd(config):
+ return _build_loader_hvd(config)
+
diff --git a/data/build.py b/data/build.py
index 5799f253..52677e78 100644
--- a/data/build.py
+++ b/data/build.py
@@ -9,6 +9,7 @@
import torch
import numpy as np
import torch.distributed as dist
+import horovod.torch as hvd
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
@@ -95,6 +96,60 @@ def build_loader(config):
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+def build_loader_hvd(config):
+ config.defrost()
+ dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+ config.freeze()
+ print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build train dataset")
+ dataset_val, _ = build_dataset(is_train=False, config=config)
+ print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build val dataset")
+
+ num_tasks = hvd.size()
+ global_rank = hvd.rank()
+ if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
+ indices = np.arange(hvd.rank(), len(dataset_train), hvd.size())
+ sampler_train = SubsetRandomSampler(indices)
+ else:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+
+ if config.TEST.SEQUENTIAL:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ else:
+ sampler_val = torch.utils.data.distributed.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=config.TEST.SHUFFLE
+ )
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=config.DATA.BATCH_SIZE,
+ num_workers=config.DATA.NUM_WORKERS,
+ pin_memory=config.DATA.PIN_MEMORY,
+ drop_last=True,
+ )
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=config.DATA.BATCH_SIZE,
+ shuffle=False,
+ num_workers=config.DATA.NUM_WORKERS,
+ pin_memory=config.DATA.PIN_MEMORY,
+ drop_last=False
+ )
+
+ # setup mixup / cutmix
+ mixup_fn = None
+ mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+ if mixup_active:
+ mixup_fn = Mixup(
+ mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+ prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+ label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+
+ return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+
+
def build_dataset(is_train, config):
transform = build_transform(is_train, config)
if config.DATA.DATASET == 'imagenet':
diff --git a/main_hvd.py b/main_hvd.py
new file mode 100755
index 00000000..72312826
--- /dev/null
+++ b/main_hvd.py
@@ -0,0 +1,353 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import time
+import json
+import random
+import argparse
+import datetime
+import numpy as np
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import horovod.torch as hvd
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import accuracy, AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader, build_loader_hvd
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, NativeScalerWithGradNormCountHvd, auto_resume_helper, \
+ reduce_tensor, reduce_tensor_hvd
+
+
+def parse_option():
+ parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
+ parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+ parser.add_argument(
+ "--opts",
+ help="Modify config options by adding 'KEY VALUE' pairs. ",
+ default=None,
+ nargs='+',
+ )
+
+ # easy config modification
+ parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+ parser.add_argument('--data-path', type=str, help='path to dataset')
+ parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
+ parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
+ help='no: no cache, '
+ 'full: cache all data, '
+ 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
+ parser.add_argument('--pretrained',
+ help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
+ parser.add_argument('--resume', help='resume from checkpoint')
+ parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+ parser.add_argument('--use-checkpoint', action='store_true',
+ help="whether to use gradient checkpointing to save memory")
+ parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
+ parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
+ help='mixed precision opt level, if O0, no amp is used (deprecated!)')
+ parser.add_argument('--output', default='output', type=str, metavar='PATH',
+ help='root of output folder, the full path is