Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,13 @@ This repo is the official implementation of ["Swin Transformer: Hierarchical Vis

> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions.

> **Feature-Distillation**: Will appear in [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).

## Activity notification

* 09/18/2022: Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks:
- [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks.
- [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks.
> **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).

## Updates

$\qquad$ [ <img src="https://computer-vision-in-the-wild.github.io/eccv-2022/static/eccv2022/img/ECCV-logo3.png" width=10%/> [Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [<img src="https://evalai.s3.amazonaws.com/media/logos/4e939412-a9c0-46bd-9797-5ba0bd0a9095.jpg" width=10%/> [IC Challenge] ](https://eval.ai/web/challenges/challenge-page/1832/overview)
$\qquad$ [<img src="https://evalai.s3.amazonaws.com/media/logos/3a31ae6e-a990-48fb-b2c3-1e7da9d17a20.jpg" width=10%/> [OD Challenge] ](https://eval.ai/web/challenges/challenge-page/1839/overview)
***11/30/2022***

## Updates
1. Models and codes of **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details, and the checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L).

***09/24/2022***

Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
_C.SAVE_FREQ = 10
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
Expand Down
11 changes: 11 additions & 0 deletions configs/swinv2/swinv2_base_patch4_window8_128_hvd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
DATA:
IMG_SIZE: 128
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window8_128_hvd
DROP_PATH_RATE: 0.5
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 8
5 changes: 5 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .build import build_loader as _build_loader
from .build import build_loader_hvd as _build_loader_hvd
from .data_simmim_pt import build_loader_simmim
from .data_simmim_ft import build_loader_finetune

Expand All @@ -10,3 +11,7 @@ def build_loader(config, simmim=False, is_pretrain=False):
return build_loader_simmim(config)
else:
return build_loader_finetune(config)

def build_loader_hvd(config):
return _build_loader_hvd(config)

55 changes: 55 additions & 0 deletions data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import numpy as np
import torch.distributed as dist
import horovod.torch as hvd
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
Expand Down Expand Up @@ -95,6 +96,60 @@ def build_loader(config):
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn


def build_loader_hvd(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
config.freeze()
print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build train dataset")
dataset_val, _ = build_dataset(is_train=False, config=config)
print(f"local rank {config.LOCAL_RANK} / global rank {hvd.rank()} successfully build val dataset")

num_tasks = hvd.size()
global_rank = hvd.rank()
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
indices = np.arange(hvd.rank(), len(dataset_train), hvd.size())
sampler_train = SubsetRandomSampler(indices)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)

if config.TEST.SEQUENTIAL:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_val = torch.utils.data.distributed.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=config.TEST.SHUFFLE
)

data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False
)

# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn


def build_dataset(is_train, config):
transform = build_transform(is_train, config)
if config.DATA.DATASET == 'imagenet':
Expand Down
Loading