Add byol, moco, simsiam model classes
- Runs without code exceptions
itsnamgyu committed Jan 4, 2022
1 parent d86a38c commit 92193b4
Showing 4 changed files with 359 additions and 0 deletions.
6 changes: 6 additions & 0 deletions model/__init__.py
@@ -1,9 +1,15 @@
from model.base import BaseModel
from model.byol import BYOL
from model.moco import MoCo
from model.simclr import SimCLR
from model.simsiam import SimSiam

_model_class_map = {
'base': BaseModel,
'simclr': SimCLR,
'byol': BYOL,
'moco': MoCo,
'simsiam': SimSiam,
}


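The remainder of this file is truncated in the diff above. As an illustration only (not part of this commit), a lookup helper built on _model_class_map might look like the following; get_model_class is a hypothetical name, and backbone/params come from the surrounding training code.

# Illustrative sketch only; get_model_class is a hypothetical helper, not code from this commit.
def get_model_class(name: str):
    try:
        return _model_class_map[name]
    except KeyError:
        raise ValueError(f"unknown model '{name}', expected one of {sorted(_model_class_map)}")

# model_class = get_model_class('byol')
# model = model_class(backbone, params)  # backbone and params are supplied by the training script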
235 changes: 235 additions & 0 deletions model/byol.py
@@ -0,0 +1,235 @@
import copy
import random
from argparse import Namespace
from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn

from model.base import BaseSelfSupervisedModel


def _singleton(cache_key):
# Lazily build and cache the wrapped method's result on the instance attribute named by `cache_key`.
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance

instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance

return wrapper

return inner_fn


def _get_module_device(module):
return next(module.parameters()).device


def _set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val


def _loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
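As a side note (illustrative, not part of the commit): because both inputs are L2-normalized, this loss equals the squared Euclidean distance between the normalized vectors, 2 - 2*cos(x, y) = ||x_hat - y_hat||^2. A quick check:

# Quick numerical check of the identity above; not part of this commit.
import torch
import torch.nn.functional as F

x, y = torch.randn(4, 8), torch.randn(4, 8)
x_hat, y_hat = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
squared_distance = (x_hat - y_hat).pow(2).sum(dim=-1)
byol_loss = 2 - 2 * (x_hat * y_hat).sum(dim=-1)
assert torch.allclose(squared_distance, byol_loss, atol=1e-5)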


class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p

def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)


class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta

def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new


def _update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)


class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)

def forward(self, x):
return self.net(x)


class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size,
layer=-1):  # the reference BYOL default is -2 to skip the classifier head; our backbone has no classifier, so we use -1 (the full network output)
super().__init__()
self.net = net
self.layer = layer

self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size

self.hidden = {}
self.hook_registered = False

def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None

def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output.reshape(output.shape[0], -1) # flatten

def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
layer.register_forward_hook(self._hook)
self.hook_registered = True

@_singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)

def get_representation(self, x):
if self.layer == -1:
return self.net(x)

if not self.hook_registered:
self._register_hook()

self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()

assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden

def forward(self, x, return_projection=True):
representation = self.get_representation(x)

if not return_projection:
return representation

projector = self._get_projector(representation)
projection = projector(representation)
return projection, representation
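For illustration of the hook path above (not part of this commit, and assuming torchvision is available): passing a module name as layer makes NetWrapper capture that module's output as the representation.

# Illustrative only; 'avgpool' is the global-pooling module name in torchvision's ResNet.
import torch
from torchvision.models import resnet18

wrapper = NetWrapper(resnet18(), projection_size=256, projection_hidden_size=4096, layer='avgpool')
projection, representation = wrapper(torch.randn(2, 3, 224, 224))
# representation: (2, 512) flattened 'avgpool' output; projection: (2, 256) from the lazily built projector MLP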


class BYOL(BaseSelfSupervisedModel):
def __init__(self, backbone: nn.Module, params: Namespace, use_momentum=True):
super().__init__(backbone, params)

image_size = 224
hidden_layer = -1
projection_size = 256
projection_hidden_size = 4096
moving_average_decay = 0.99

self.online_encoder = NetWrapper(self.backbone, projection_size, projection_hidden_size, layer=hidden_layer)

self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)

self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

# get device of network and make wrapper same device
device = _get_module_device(backbone)
self.to(device)

# send a mock image tensor to instantiate singleton parameters
self.compute_ssl_loss(torch.randn(2, 3, image_size, image_size, device=device),
torch.randn(2, 3, image_size, image_size, device=device))

@_singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
_set_requires_grad(target_encoder, False)
return target_encoder

def _reset_moving_average(self):
del self.target_encoder
self.target_encoder = None

def _update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
_update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

def compute_ssl_loss(self, x1, x2=None, return_features=False):
if x2 is None:
x = x1
batch_size = int(x.shape[0] / 2)
x1 = x[:batch_size]
x2 = x[batch_size:]

assert not (self.training and x1.shape[0] == 1), \
'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

online_proj_one, _ = self.online_encoder(x1)
online_proj_two, _ = self.online_encoder(x2)

online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)

with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(x1)
target_proj_two, _ = target_encoder(x2)
target_proj_one.detach_()
target_proj_two.detach_()

loss_one = _loss_fn(online_pred_one, target_proj_two)
loss_two = _loss_fn(online_pred_two, target_proj_one)

loss = loss_one + loss_two
loss = loss.mean()

if return_features:
if x2 is None:
return loss, torch.cat([online_proj_one, online_proj_two])
else:
return loss, online_proj_one, online_proj_two
else:
return loss

def on_step_end(self):
if self.use_momentum:
self._update_moving_average()
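A minimal training-step sketch for the class above (illustrative; backbone, params, loader, and device are placeholders from the surrounding project): compute the loss on two augmented views, step the optimizer, then call on_step_end() so the EMA target tracks the online encoder.

# Minimal sketch, not part of this commit; backbone, params, loader, and device are placeholders.
model = BYOL(backbone, params).to(device)
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.05, momentum=0.9)
for view1, view2 in loader:  # two augmented views of the same batch of images
    loss = model.compute_ssl_loss(view1.to(device), view2.to(device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.on_step_end()  # EMA update of the target encoder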

108 changes: 108 additions & 0 deletions model/moco.py
@@ -0,0 +1,108 @@
import copy
from argparse import Namespace

import torch
import torch.nn.functional as F
from torch import nn

from model.base import BaseSelfSupervisedModel


class MoCo(BaseSelfSupervisedModel):
def __init__(self, backbone: nn.Module, params: Namespace):
super().__init__(backbone, params)

dim = 128
mlp = False
self.K = 1024
self.m = 0.999
self.T = 1.0

self.encoder_q = self.backbone
self.encoder_k = copy.deepcopy(self.backbone)

if not mlp:
self.projector_q = nn.Linear(self.encoder_q.final_feat_dim, dim)
self.projector_k = nn.Linear(self.encoder_k.final_feat_dim, dim)
else:
mlp_dim = self.encoder_q.final_feat_dim
self.projector_q = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
self.projector_k = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))

self.encoder_k.requires_grad_(False)
self.projector_k.requires_grad_(False)
# Just in case (copied from old code)
for param_k in self.encoder_k.parameters():
param_k.requires_grad = False
for param_k in self.projector_k.parameters():
param_k.requires_grad = False

self.register_buffer("queue", torch.randn(dim, self.K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

self.ce_loss = nn.CrossEntropyLoss()

@torch.no_grad()
def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
for param_q_, param_k_ in zip(self.projector_q.parameters(), self.projector_k.parameters()):
param_k_.data = param_k_.data * self.m + param_q_.data * (1. - self.m)

@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.K % batch_size == 0 # for simplicity

# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K # move pointer

self.queue_ptr[0] = ptr
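For intuition (illustrative, not from the commit): the pointer advances in fixed-size steps and wraps around, which is why K must be divisible by the batch size.

# Illustrative pointer arithmetic only.
K, batch_size, ptr = 1024, 128, 0
positions = []
for _ in range(10):
    positions.append(ptr)
    ptr = (ptr + batch_size) % K
# positions == [0, 128, 256, 384, 512, 640, 768, 896, 0, 128]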

def compute_ssl_loss(self, x1, x2=None, return_features=False):
if x2 is None:
x = x1
batch_size = int(x.shape[0] / 2)
im_q = x[:batch_size]
im_k = x[batch_size:]
else:
im_q = x1
im_k = x2

q_features = self.encoder_q(im_q)
q = self.projector_q(q_features) # queries: NxC
q = F.normalize(q, dim=1)

# compute key features
with torch.no_grad(): # no gradient to keys
self._momentum_update_key_encoder() # update the key encoder

k_features = self.encoder_k(im_k)
k = self.projector_k(k_features) # keys: NxC
k = F.normalize(k, dim=1)

# compute logits (Einstein sum is more intuitive)
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) # positive logits: Nx1
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) # negative logits: NxK

logits = torch.cat([l_pos, l_neg], dim=1) # logits: Nx(1+K)
logits /= self.T # apply temperature
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device) # labels: positive key indicators

self._dequeue_and_enqueue(k)

loss = self.ce_loss(logits, labels)

if return_features:
if x2 is None:
return loss, torch.cat([q_features, k_features])
else:
return loss, q_features, k_features
else:
return loss
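To make the logits layout above concrete (illustrative only, not part of the commit): each row holds one positive similarity followed by K negatives from the queue, so the correct class index is always 0.

# Shape sketch only: N queries against 1 positive key and K queued negatives.
import torch
import torch.nn.functional as F
from torch import nn

N, C, K = 8, 128, 1024
q = F.normalize(torch.randn(N, C), dim=1)
k = F.normalize(torch.randn(N, C), dim=1)
queue = F.normalize(torch.randn(C, K), dim=0)
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # (N, 1)
l_neg = torch.einsum('nc,ck->nk', [q, queue])           # (N, K)
logits = torch.cat([l_pos, l_neg], dim=1)               # (N, 1 + K)
labels = torch.zeros(N, dtype=torch.long)               # positive sits in column 0
loss = nn.CrossEntropyLoss()(logits, labels)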
10 changes: 10 additions & 0 deletions model/simsiam.py
@@ -0,0 +1,10 @@
from argparse import Namespace

from torch import nn

from model import BYOL


class SimSiam(BYOL):
def __init__(self, backbone: nn.Module, params: Namespace):
super().__init__(backbone, params, use_momentum=False)
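With use_momentum=False, BYOL's target branch is simply the online encoder run under torch.no_grad, which plays the role of SimSiam's stop-gradient, and no EMA update is needed at step end. A minimal sketch (backbone and params are placeholders):

# Minimal sketch, not part of this commit; backbone and params are placeholders.
model = SimSiam(backbone, params)
assert model.use_momentum is False  # target branch = online encoder under torch.no_grad (stop-gradient)
model.on_step_end()                 # no-op: there is no EMA target to update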
