-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsal.py
executable file
·97 lines (77 loc) · 3.51 KB
/
sal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from algos.algorithm import Algorithm
from config import CONFIG
from models import Classifier
from utils import random_choice_noreplace
def randomly_reverse_indices(indices):
  """Return `indices` with its second axis flipped half of the time.

  A single scalar coin is drawn for the whole tensor, so either every row
  keeps its order or every row is reversed along axis 1.

  Args:
    indices: Tensor of rank >= 2 whose axis 1 may be reversed.

  Returns:
    `indices` unchanged when the uniform draw is below 0.5, otherwise
    `indices` reversed along axis 1.
  """
  keep_order = tf.random.uniform(()) < 0.5
  return tf.cond(keep_order,
                 lambda: indices,
                 lambda: tf.reverse(indices, axis=[1]))
def get_shuffled_indices_and_labels(batch_size, num_samples, shuffle_fraction,
                                    num_steps):
  """Produce possibly shuffled frame-index triplets and their labels.

  For each of the `batch_size * num_samples` samples, the first five indices
  returned by `random_choice_noreplace(total, num_steps)` are sorted
  ascending along each row. The whole batch is then optionally reversed (see
  `randomly_reverse_indices`). From the resulting 5 columns:
    * ordered samples take columns (1, 2, 3);
    * shuffled samples take columns (1, 0, 3) or (1, 4, 3), chosen per sample
      with probability 0.5 each.
  Exactly `int(shuffle_fraction * batch_size * num_samples)` samples are
  labelled as shuffled, in random positions.

  Args:
    batch_size: Python int, number of sequences.
    num_samples: Python int, number of triplets drawn per sequence.
    shuffle_fraction: Python number in [0, 1], fraction of shuffled triplets.
    num_steps: Number of frames to pick from; must be at least 5.

  Returns:
    A pair `(indices, shuffle_labels)`. `indices` holds one row of three
    frame indices per sample. `shuffle_labels` is an int32 vector with 1 for
    shuffled samples and 0 for ordered ones.

  Raises:
    ValueError: If `shuffle_fraction` is outside [0, 1], or if `num_steps`
      is a Python int smaller than 5.
  """
  # Outside [0, 1] the label list below would not have one entry per sample,
  # which later surfaces as an obscure shape error in tf.where.
  if not 0.0 <= shuffle_fraction <= 1.0:
    raise ValueError('shuffle_fraction must be in [0, 1], got %r.'
                     % (shuffle_fraction,))
  # Five candidate frames are needed for the column gathers below. With fewer,
  # tf.gather indexes out of range, which on GPU silently yields 0 instead of
  # raising. Only statically known (Python int) sizes can be checked here.
  if isinstance(num_steps, int) and num_steps < 5:
    raise ValueError('num_steps must be at least 5, got %d.' % num_steps)

  total_num_samples = batch_size * num_samples
  num_shuffled_examples = int(shuffle_fraction * total_num_samples)
  shuffle_labels = tf.random.shuffle(tf.cast(
      num_shuffled_examples * [1] +
      (total_num_samples - num_shuffled_examples) * [0], tf.int32))
  indices = tf.sort(random_choice_noreplace(
      total_num_samples, num_steps)[:, :5], axis=1)
  indices = randomly_reverse_indices(indices)
  # Per sample: replace the middle frame with either an earlier (col 0) or a
  # later (col 4) frame to build an out-of-order triplet.
  shuffled_samples = tf.where(
      tf.less_equal(tf.random.uniform((total_num_samples, 1)), 0.5),
      tf.gather(indices, [1, 0, 3], axis=1),
      tf.gather(indices, [1, 4, 3], axis=1))
  ordered_samples = tf.gather(indices, [1, 2, 3], axis=1)
  indices = tf.where(tf.equal(tf.expand_dims(shuffle_labels, axis=-1), 1),
                     shuffled_samples, ordered_samples)
  return indices, shuffle_labels
def sample_batch(embs, batch_size, num_steps):
  """Build Shuffle-and-Learn classifier inputs and their labels.

  Every sequence in `embs` is repeated CONFIG.SAL.NUM_SAMPLES times. For each
  copy, a triplet of frame indices comes from
  `get_shuffled_indices_and_labels`. The three gathered frame embeddings are
  then joined along the last axis.

  Args:
    embs: Rank-3 embedding tensor with frames on axis 1. Presumably
      (batch_size, num_steps, dim); TODO confirm against the caller.
    batch_size: Number of sequences in `embs`.
    num_steps: Number of frames per sequence.

  Returns:
    A pair `(concat_embs, labels)`. `concat_embs` has one row per sampled
    triplet with the three frame embeddings concatenated. `labels` is the
    gradient-stopped one-hot (depth 2) shuffle label per row.
  """
  fraction = CONFIG.SAL.SHUFFLE_FRACTION
  repeats = CONFIG.SAL.NUM_SAMPLES
  frame_idx, is_shuffled = get_shuffled_indices_and_labels(
      batch_size, repeats, fraction, num_steps)
  one_hot_labels = tf.stop_gradient(tf.one_hot(is_shuffled, 2))
  frame_idx = tf.stop_gradient(frame_idx)

  # Repeat the batch so that each sampled triplet gets its own row.
  repeated = tf.tile(embs, [repeats, 1, 1])
  # Row-wise gather; batch_dims=-1 resolves to rank(frame_idx) - 1.
  triplets = tf.gather(repeated, frame_idx, axis=1, batch_dims=-1)
  # Place the three frames side by side along the feature axis.
  side_by_side = tf.concat(tf.split(triplets, 3, axis=1), axis=-1)
  return tf.squeeze(side_by_side, axis=1), one_hot_labels
class SaL(Algorithm):
  """Shuffle and Learn algorithm (https://arxiv.org/abs/1603.08561).

  A classifier head stored as `self.model['sal_classifier']` is trained to
  predict, for a triplet of frame embeddings, one of two classes: correct
  temporal order or shuffled.
  """

  def __init__(self, model=None):
    """Create the algorithm and attach the Shuffle and Learn classifier.

    Args:
      model: Passed through unchanged to `Algorithm.__init__`.

    Raises:
      ValueError: If the first entry of the last CONFIG.SAL.FC_LAYERS spec is
        not 2 (presumably the output width of the final layer).
    """
    super(SaL, self).__init__(model)
    # The head distinguishes exactly two classes (ordered vs. shuffled).
    if CONFIG.SAL.FC_LAYERS[-1][0] != 2:
      raise ValueError('Shuffle and Learn classifier has only 2 classes: '
                       'correct order or incorrect order. Ensure last layer in '
                       'config.sal.fc_layers is 2.')
    sal_classifier = Classifier(CONFIG.SAL.FC_LAYERS, CONFIG.SAL.DROPOUT_RATE)
    self.model['sal_classifier'] = sal_classifier

  def get_algo_variables(self):
    """Return the variables of the Shuffle and Learn classifier head."""
    return self.model['sal_classifier'].variables

  def compute_loss(self, embs, steps, seq_lens, global_step, training,
                   frame_labels, seq_labels):
    """Compute the mean Shuffle-and-Learn classification loss.

    Args:
      embs: Frame embeddings, forwarded to `sample_batch`.
      steps: Unused; accepted for interface compatibility.
      seq_lens: Unused; accepted for interface compatibility.
      global_step: Unused; accepted for interface compatibility.
      training: Selects the CONFIG.TRAIN (True) or CONFIG.EVAL (False)
        batch size and frame count.
      frame_labels: Unused; accepted for interface compatibility.
      seq_labels: Unused; accepted for interface compatibility.

    Returns:
      Scalar mean categorical cross-entropy, computed from logits with
      CONFIG.SAL.LABEL_SMOOTHING applied.
    """
    if training:
      batch_size = CONFIG.TRAIN.BATCH_SIZE
      num_steps = CONFIG.TRAIN.NUM_FRAMES
    else:
      batch_size = CONFIG.EVAL.BATCH_SIZE
      num_steps = CONFIG.EVAL.NUM_FRAMES

    concat_embs, labels = sample_batch(embs, batch_size, num_steps)
    logits = self.model['sal_classifier'](concat_embs)
    loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(
            y_true=labels,
            y_pred=logits,
            from_logits=True,
            label_smoothing=CONFIG.SAL.LABEL_SMOOTHING))
    return loss