Commit e9681bb
Fatemeh Taheri committed Jan 10, 2021 · 0 parents
Showing 40 changed files with 4,302 additions and 0 deletions.
Generated, large, and binary files in this commit are not rendered below.
@@ -0,0 +1,34 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""List all available algorithms."""

# from alignment import Alignment
from algos.alignment_sal_tcn import AlignmentSaLTCN
# from method.algos.classification import Classification
# from method.algos.sal import SaL
# from method.algos.tcn import TCN

ALGO_NAME_TO_ALGO_CLASS = {
    'alignment_sal_tcn': AlignmentSaLTCN,
}


def get_algo(algo_name):
  """Returns training algo."""
  if algo_name not in ALGO_NAME_TO_ALGO_CLASS:
    raise ValueError('%s not supported yet.' % algo_name)
  algo = ALGO_NAME_TO_ALGO_CLASS[algo_name]
  return algo()
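For orientation, a minimal usage sketch of the registry above. The import path is an assumption based on the `from algos...` imports elsewhere in this commit; the diff itself does not confirm where this file lives.

# Hypothetical usage of the registry; 'alignment_sal_tcn' is the only
# algorithm registered in this commit.
from algos import get_algo  # assumed import path

algo = get_algo('alignment_sal_tcn')  # returns an AlignmentSaLTCN instance
# Any other name raises: ValueError("<name> not supported yet.")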
@@ -0,0 +1,154 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import tensorflow as tf

from config import CONFIG
from models import get_model
from utils import get_cnn_feats
from utils import set_learning_phase

DEBUG = False


class Algorithm(tf.keras.Model):
  """Base class for defining algorithms."""
  __metaclass__ = abc.ABCMeta

  def __init__(self, model=None):
    super(Algorithm, self).__init__()
    if model:
      self.model = model
    else:
      self.model = get_model()

  @set_learning_phase
  @abc.abstractmethod
  def call(self, data, steps, seq_lens, training):
    """One pass through the model.

    Args:
      data: dict, batches of tensors from many videos. Available keys: 'audio',
        'frames', 'labels'.
      steps: Tensor, batch of indices of chosen frames in videos.
      seq_lens: Tensor, batch of sequence lengths of the full videos.
      training: Boolean, if True model is run in training mode.

    Returns:
      embeddings: Tensor, float tensor containing embeddings.

    Raises:
      ValueError: In case invalid configs are passed.
    """
    cnn = self.model['cnn']
    emb = self.model['emb']

    if training:
      num_steps = CONFIG.TRAIN.NUM_FRAMES
    else:
      num_steps = CONFIG.EVAL.NUM_FRAMES

    cnn_feats = get_cnn_feats(cnn, data, training)

    embs = emb(cnn_feats, num_steps)
    channels = embs.shape[-1]
    embs = tf.reshape(embs, [-1, num_steps, channels])

    return embs

  @abc.abstractmethod
  def compute_loss(self, embs, steps, seq_lens, global_step, training,
                   frame_labels=None, seq_labels=None):
    pass

  def get_base_and_embedding_variables(self):
    """Gets list of trainable vars from model's base and embedding networks.

    Returns:
      variables: List, list of variables we want to train.
    """
    if CONFIG.MODEL.TRAIN_BASE == 'train_all':
      variables = self.model['cnn'].variables
    elif CONFIG.MODEL.TRAIN_BASE == 'only_bn':
      # TODO(debidatta): Better way to extract batch norm variables.
      variables = [x for x in self.model['cnn'].variables
                   if 'batch_norm' in x.name or 'bn' in x.name]
    elif CONFIG.MODEL.TRAIN_BASE == 'frozen':
      variables = []
    else:
      raise ValueError('train_base values supported right now: train_all, '
                       'only_bn or frozen.')
    if CONFIG.MODEL.TRAIN_EMBEDDING:
      variables += self.model['emb'].variables
    return variables

  @abc.abstractmethod
  def get_algo_variables(self):
    return []

  @property
  def variables(self):
    """Returns list of variables to train.

    Returns:
      variables: list, contains variables that will be trained.
    """
    variables = [x for x in self.get_base_and_embedding_variables()
                 if 'moving' not in x.name]
    variables += [x for x in self.get_algo_variables()
                  if 'moving' not in x.name]
    return variables

  def compute_gradients(self, loss, tape=None):
    """This is to be used in Eager mode when a GradientTape is available."""
    if tf.executing_eagerly():
      assert tape is not None
      gradients = tape.gradient(loss, self.variables)
    else:
      gradients = tf.gradients(loss, self.variables)
    return gradients

  def apply_gradients(self, optimizer, grads):
    """Functional style apply_grads for `tfe.defun`."""
    optimizer.apply_gradients(zip(grads, self.variables))

  def train_one_iter(self, data, steps, seq_lens, global_step, optimizer):
    with tf.GradientTape() as tape:
      embs = self.call(data, steps, seq_lens, training=True)
      loss = self.compute_loss(embs, steps, seq_lens, global_step,
                               training=True,
                               frame_labels=data['frame_labels'],
                               seq_labels=data['seq_labels'])
      # Add regularization losses.
      reg_loss = tf.reduce_mean(tf.stack(self.losses))
      tf.summary.scalar('reg_loss', reg_loss, step=global_step)
      loss += reg_loss

      # Be careful not to use object-based losses in tf.keras.losses
      # (CategoricalCrossentropy) or tf.losses (softmax_cross_entropy). The
      # above losses scale by number of GPUs on their own, which can lead to
      # inconsistent scaling. Hence, always use the functional losses
      # defined in tf.keras.losses (e.g. categorical_crossentropy).
      # Divide by number of replicas.
      strategy = tf.distribute.get_strategy()
      num_replicas = strategy.num_replicas_in_sync
      loss *= (1. / num_replicas)

    gradients = self.compute_gradients(loss, tape)
    self.apply_gradients(optimizer, gradients)

    if DEBUG:
      for v, g in zip(self.variables, gradients):
        norm = tf.reduce_sum(g * g)
        tf.summary.scalar('grad_norm_%s' % v.name, norm, step=global_step)
      grad_norm = tf.reduce_mean(tf.stack([tf.reduce_sum(grad * grad)
                                           for grad in gradients]))
      tf.summary.scalar('grad_norm', grad_norm, step=global_step)
      for k in self.model:
        for var_ in self.model[k].variables:
          tf.summary.histogram(var_.name, var_, step=global_step)

    return loss
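To make the abstract interface concrete, here is a hypothetical minimal subclass. It is not part of this commit, and it assumes the repo's config, models, and utils modules are importable; the toy loss exists only to show which methods a new algorithm must supply.

import tensorflow as tf

from algos.algorithm import Algorithm


class DummyAlgo(Algorithm):
  """Toy algorithm that pulls embeddings toward zero. Illustration only."""

  def call(self, data, steps, seq_lens, training):
    # Reuse the base class forward pass (CNN features -> embedder).
    return super(DummyAlgo, self).call(data, steps, seq_lens, training)

  def compute_loss(self, embs, steps, seq_lens, global_step, training,
                   frame_labels=None, seq_labels=None):
    # Any scalar tensor works here; train_one_iter adds regularization
    # and divides by the number of replicas.
    return tf.reduce_mean(tf.square(embs))

  def get_algo_variables(self):
    # No trainable variables beyond the base and embedding networks.
    return []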
@@ -0,0 +1,44 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

from algos.algorithm import Algorithm
from config import CONFIG
from tcc.alignment import compute_alignment_loss

FLAGS = flags.FLAGS


class Alignment(Algorithm):
  """Uses cycle-consistency loss to perform unsupervised training."""

  def compute_loss(self, embs, steps, seq_lens, global_step, training,
                   frame_labels, seq_labels):
    if training:
      batch_size = CONFIG.TRAIN.BATCH_SIZE
      num_steps = CONFIG.TRAIN.NUM_FRAMES
    else:
      batch_size = CONFIG.EVAL.BATCH_SIZE
      num_steps = CONFIG.EVAL.NUM_FRAMES

    loss = compute_alignment_loss(
        embs,
        batch_size,
        steps=steps,
        seq_lens=seq_lens,
        stochastic_matching=CONFIG.ALIGNMENT.STOCHASTIC_MATCHING,
        normalize_embeddings=False,
        loss_type=CONFIG.ALIGNMENT.LOSS_TYPE,
        similarity_type=CONFIG.ALIGNMENT.SIMILARITY_TYPE,
        num_cycles=int(batch_size * num_steps * CONFIG.ALIGNMENT.FRACTION),
        cycle_length=CONFIG.ALIGNMENT.CYCLE_LENGTH,
        temperature=CONFIG.ALIGNMENT.SOFTMAX_TEMPERATURE,
        label_smoothing=CONFIG.ALIGNMENT.LABEL_SMOOTHING,
        variance_lambda=CONFIG.ALIGNMENT.VARIANCE_LAMBDA,
        huber_delta=CONFIG.ALIGNMENT.HUBER_DELTA,
        normalize_indices=CONFIG.ALIGNMENT.NORMALIZE_INDICES)

    return loss
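As a quick check on the `num_cycles` arithmetic above, a sketch with assumed config values; none of these numbers appear in this commit.

# Illustrative only: how num_cycles is derived from config values.
batch_size = 4   # CONFIG.TRAIN.BATCH_SIZE (assumed)
num_steps = 20   # CONFIG.TRAIN.NUM_FRAMES (assumed)
fraction = 0.5   # CONFIG.ALIGNMENT.FRACTION (assumed)
num_cycles = int(batch_size * num_steps * fraction)  # 40 sampled cycles per batch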
@@ -0,0 +1,127 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Alignment + SaL + TCN loss for unsupervised training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from algos.alignment import Alignment
from algos.sal import SaL
from algos.tcn import TCN
from config import CONFIG


class AlignmentSaLTCN(TCN):
  """Network trained with a combination of losses."""

  def __init__(self, model=None):
    super(AlignmentSaLTCN, self).__init__(model)
    algo_config = CONFIG.ALIGNMENT_SAL_TCN
    self.alignment_loss_weight = algo_config.ALIGNMENT_LOSS_WEIGHT
    self.sal_loss_weight = algo_config.SAL_LOSS_WEIGHT
    self.tcn_loss_weight = (1.0 - self.alignment_loss_weight -
                            self.sal_loss_weight)
    if self.alignment_loss_weight + self.sal_loss_weight > 1.0:
      raise ValueError('Sum of weights > 1 not allowed.')
    if self.alignment_loss_weight < 0 or self.sal_loss_weight < 0:
      raise ValueError('Negative weights not allowed.')

    self.algos = []
    if self.alignment_loss_weight > 0:
      self.alignment_algo = Alignment(self.model)
      self.algos.append(self.alignment_algo)
    if self.sal_loss_weight > 0:
      self.sal_algo = SaL(self.model)
      self.algos.append(self.sal_algo)
    if self.tcn_loss_weight > 0:
      self.tcn_algo = TCN(self.model)
      self.algos.append(self.tcn_algo)

  def get_algo_variables(self):
    algo_variables = []
    for algo in self.algos:
      algo_variables.extend(algo.get_algo_variables())
    return algo_variables

  def compute_loss(self, embs, steps, seq_lens, global_step, training,
                   frame_labels, seq_labels):
    if self.tcn_loss_weight != 0.0:
      tcn_loss = self.tcn_algo.compute_loss(embs, steps, seq_lens, global_step,
                                            training, frame_labels, seq_labels)
      if training:
        tf.summary.scalar('alignment_sal_tcn/tcn_loss', tcn_loss,
                          step=global_step)
    else:
      tcn_loss = 0.0

    if self.alignment_loss_weight != 0.0 or self.sal_loss_weight != 0.0:
      if training:
        batch_size = CONFIG.TRAIN.BATCH_SIZE
        num_steps = CONFIG.TRAIN.NUM_FRAMES
      else:
        batch_size = CONFIG.EVAL.BATCH_SIZE
        num_steps = CONFIG.EVAL.NUM_FRAMES

      embs_list = []
      steps_list = []
      seq_lens_list = []

      for i in range(int(batch_size)):
        # Randomly sample half of the TCN frames: datasets.py already samples
        # double the number of frames because TCN requires positives for
        # training.
        chosen_steps = tf.cond(tf.random.uniform(()) < 0.5,
                               lambda: tf.range(0, 2 * num_steps, 2),
                               lambda: tf.range(1, 2 * num_steps, 2))

        embs_ = tf.gather(embs[i], chosen_steps)
        steps_ = tf.gather(steps[i], chosen_steps)

        embs_list.append(embs_)
        steps_list.append(steps_)
        seq_lens_list.append(seq_lens[i])

      embs = tf.stack(embs_list)
      steps = tf.stack(steps_list)
      seq_lens = tf.stack(seq_lens_list)

    if self.alignment_loss_weight != 0:
      # Matches the Alignment.compute_loss signature defined earlier in this
      # commit (embs, steps, seq_lens, global_step, training, frame_labels,
      # seq_labels).
      alignment_loss = self.alignment_algo.compute_loss(embs, steps, seq_lens,
                                                        global_step, training,
                                                        frame_labels,
                                                        seq_labels)
      if training:
        tf.summary.scalar('alignment_sal_tcn/alignment_loss',
                          alignment_loss, step=global_step)
    else:
      alignment_loss = 0.0

    if self.sal_loss_weight != 0:
      sal_loss = self.sal_algo.compute_loss(embs, steps, seq_lens, global_step,
                                            training, frame_labels, seq_labels)
      if training:
        tf.summary.scalar('alignment_sal_tcn/sal_loss', sal_loss,
                          step=global_step)
    else:
      sal_loss = 0.0

    return (self.alignment_loss_weight * alignment_loss +
            self.sal_loss_weight * sal_loss +
            self.tcn_loss_weight * tcn_loss)
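Finally, a sketch of the weight bookkeeping in `__init__` above, with assumed config values: the TCN weight is whatever remains after the alignment and SaL weights are taken out of 1.0.

# Illustrative weight arithmetic; the two config weights are assumptions.
alignment_loss_weight = 0.3  # algo_config.ALIGNMENT_LOSS_WEIGHT (assumed)
sal_loss_weight = 0.3        # algo_config.SAL_LOSS_WEIGHT (assumed)
tcn_loss_weight = 1.0 - alignment_loss_weight - sal_loss_weight  # 0.4
# total_loss = 0.3 * alignment_loss + 0.3 * sal_loss + 0.4 * tcn_loss
assert alignment_loss_weight + sal_loss_weight <= 1.0
assert alignment_loss_weight >= 0 and sal_loss_weight >= 0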