Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,13 @@ This repo is the official implementation of ["Swin Transformer: Hierarchical Vis

> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions.

> **Feature-Distillation**: Will appear in [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).

## Activity notification

* 09/18/2022: Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks:
- [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks.
- [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks.
> **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).

## Updates

$\qquad$ [ <img src="https://computer-vision-in-the-wild.github.io/eccv-2022/static/eccv2022/img/ECCV-logo3.png" width=10%/> [Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [<img src="https://evalai.s3.amazonaws.com/media/logos/4e939412-a9c0-46bd-9797-5ba0bd0a9095.jpg" width=10%/> [IC Challenge] ](https://eval.ai/web/challenges/challenge-page/1832/overview)
$\qquad$ [<img src="https://evalai.s3.amazonaws.com/media/logos/3a31ae6e-a990-48fb-b2c3-1e7da9d17a20.jpg" width=10%/> [OD Challenge] ](https://eval.ai/web/challenges/challenge-page/1839/overview)
***11/30/2022***

## Updates
1. Models and codes of **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details, and the checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L).

***09/24/2022***

Expand Down
6 changes: 4 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.EPOCHS = 90
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
Expand Down Expand Up @@ -245,7 +245,7 @@
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
_C.SAVE_FREQ = 10
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
Expand Down Expand Up @@ -291,6 +291,8 @@ def _check_args(name):
# merge from specific arguments
if _check_args('batch_size'):
config.DATA.BATCH_SIZE = args.batch_size
if _check_args('device'):
config.DEVICE = args.device
if _check_args('data_path'):
config.DATA.DATA_PATH = args.data_path
if _check_args('zip'):
Expand Down
11 changes: 11 additions & 0 deletions configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
DATA:
IMG_SIZE: 128
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window8_128_hvd
DROP_PATH_RATE: 0.5
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 8
5 changes: 5 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .build import build_loader as _build_loader
from .build import build_loader_hvd as _build_loader_hvd
from .data_simmim_pt import build_loader_simmim
from .data_simmim_ft import build_loader_finetune

Expand All @@ -10,3 +11,7 @@ def build_loader(config, simmim=False, is_pretrain=False):
return build_loader_simmim(config)
else:
return build_loader_finetune(config)

def build_loader_hvd(config):
return _build_loader_hvd(config)

62 changes: 59 additions & 3 deletions data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import torch
import numpy as np
import torch.distributed as dist
import horovod.torch as hvd
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
#from timm.data import Mixup
from .mixup import Mixup
from timm.data import create_transform

from .cached_image_folder import CachedImageFolder
Expand Down Expand Up @@ -90,7 +92,61 @@ def build_loader(config):
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES, device=config.DEVICE)

return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn


def build_loader_hvd(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
config.freeze()
print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build train dataset")
dataset_val, _ = build_dataset(is_train=False, config=config)
print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build val dataset")

num_tasks = hvd.size()
global_rank = hvd.rank()
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
indices = np.arange(hvd.rank(), len(dataset_train), hvd.size())
sampler_train = SubsetRandomSampler(indices)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)

if config.TEST.SEQUENTIAL:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_val = torch.utils.data.distributed.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=config.TEST.SHUFFLE
)

data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False
)

# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES, device=config.DEVICE)

return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn

Expand All @@ -107,7 +163,7 @@ def build_dataset(is_train, config):
else:
root = os.path.join(config.DATA.DATA_PATH, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
nb_classes = 200
elif config.DATA.DATASET == 'imagenet22K':
prefix = 'ILSVRC2011fall_whole'
if is_train:
Expand Down
Loading