Skip to content

Commit 4650b68

Browse files
Merge pull request #29 from otiliastr:vat
PiperOrigin-RevId: 274596242
2 parents e63a9e7 + 25d945a commit 4650b68

File tree

8 files changed

+253
-31
lines changed

8 files changed

+253
-31
lines changed

neural_structured_learning/research/gam/data/dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import os
1818
import pickle
19+
1920
from gam.data.preprocessing import split_train_val
2021
import numpy as np
2122
import scipy
@@ -267,7 +268,6 @@ def _agreement_cond(edge):
267268
return self.get_labels(edge.src) == self.get_labels(edge.tgt)
268269

269270
agreement_cond = _agreement_cond if label_must_match else lambda e: True
270-
271271
return [
272272
e for e in self.edges if _labeled_cond(e.src, src_labeled) and
273273
_labeled_cond(e.tgt, tgt_labeled) and agreement_cond(e)
@@ -286,7 +286,6 @@ def __init__(self,
286286
test_mask,
287287
labels,
288288
row_normalize=False):
289-
290289
# Extract train, val, test, unlabeled indices.
291290
train_indices = np.where(train_mask)[0]
292291
test_indices = np.where(test_mask)[0]

neural_structured_learning/research/gam/data/loaders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import os
2323
import pickle
24+
2425
import sys
2526

2627
from gam.data.dataset import Dataset

neural_structured_learning/research/gam/experiments/run_train_gam.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,21 @@
4141

4242
FLAGS = flags.FLAGS
4343
flags.DEFINE_string(
44-
'dataset_name', '',
44+
'dataset_name', 'cifar10',
4545
'Dataset name. Supported options are: mnist, cifar10, cifar100, '
4646
'svhn_cropped, fashion_mnist.')
4747
flags.DEFINE_string(
4848
'data_source', 'tensorflow_datasets', 'Data source. Valid options are: '
4949
'`tensorflow_datasets`, `realistic_ssl`, `planetoid`.')
50-
flags.DEFINE_integer('target_num_train_per_class', 400,
51-
'Number of samples per class to use for training.')
52-
flags.DEFINE_integer('target_num_val', 1000,
53-
'Number of samples to be used for validation.')
54-
flags.DEFINE_integer('seed', 123, 'Seed used by the random number generators.')
50+
flags.DEFINE_integer(
51+
'target_num_train_per_class', 400,
52+
'Number of samples per class to use for training.')
53+
flags.DEFINE_integer(
54+
'target_num_val', 1000,
55+
'Number of samples to be used for validation.')
56+
flags.DEFINE_integer(
57+
'seed', 123,
58+
'Seed used by the random number generators.')
5559
flags.DEFINE_bool(
5660
'load_preprocessed', False,
5761
'Specifies whether to load data already preprocessed. If False, it reads'
@@ -222,6 +226,12 @@
222226
'num_pairs_reg', 128,
223227
'Number of pairs of nodes to use in the agreement loss term of the '
224228
'classification model.')
229+
flags.DEFINE_float(
230+
'reg_weight_vat', 0.0,
231+
'Regularization weight for the virtual adversarial training (VAT) loss.')
232+
flags.DEFINE_bool(
233+
'use_ent_min', False,
234+
'A boolean specifying whether to add entropy minimization to VAT.')
225235
flags.DEFINE_string(
226236
'aggregation_agr_inputs', 'dist',
227237
'Operation to apply on the pair of nodes in the agreement model. '
@@ -421,6 +431,8 @@ def main(argv):
421431
reg_weight_ll=FLAGS.reg_weight_ll,
422432
reg_weight_lu=FLAGS.reg_weight_lu,
423433
reg_weight_uu=FLAGS.reg_weight_uu,
434+
reg_weight_vat=FLAGS.reg_weight_vat,
435+
use_ent_min=FLAGS.use_ent_min,
424436
num_pairs_reg=FLAGS.num_pairs_reg,
425437
penalize_neg_agr=FLAGS.penalize_neg_agr,
426438
use_l2_cls=FLAGS.use_l2_cls,

neural_structured_learning/research/gam/experiments/run_train_gam_graph.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import numpy as np
3737
import tensorflow as tf
3838

39+
3940
FLAGS = flags.FLAGS
4041
flags.DEFINE_string(
4142
'dataset_name', 'cora',
@@ -196,6 +197,12 @@
196197
'num_pairs_reg', 128,
197198
'Number of pairs of nodes to use in the agreement loss term of the '
198199
'classification model.')
200+
flags.DEFINE_float(
201+
'reg_weight_vat', 0.0,
202+
'Regularization weight for the virtual adversarial training (VAT) loss.')
203+
flags.DEFINE_bool(
204+
'use_ent_min', False,
205+
'A boolean specifying whether to add entropy minimization to VAT.')
199206
flags.DEFINE_string(
200207
'aggregation_agr_inputs', 'dist',
201208
'Operation to apply on the pair of nodes in the agreement model. '
@@ -291,6 +298,8 @@ def main(argv):
291298
model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
292299
model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
293300
model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
301+
model_name += '-VAT' if FLAGS.reg_weight_vat > 0 else ''
302+
model_name += 'ENT' if FLAGS.reg_weight_vat > 0 and FLAGS.use_ent_min else ''
294303
model_name += '-transd' if not FLAGS.inductive else ''
295304
model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
296305
model_name += '-graph' if FLAGS.use_graph else '-noGraph'
@@ -380,6 +389,8 @@ def main(argv):
380389
reg_weight_lu=FLAGS.reg_weight_lu,
381390
reg_weight_uu=FLAGS.reg_weight_uu,
382391
num_pairs_reg=FLAGS.num_pairs_reg,
392+
reg_weight_vat=FLAGS.reg_weight_vat,
393+
use_ent_min=FLAGS.use_ent_min,
383394
penalize_neg_agr=FLAGS.penalize_neg_agr,
384395
use_l2_cls=FLAGS.use_l2_cls,
385396
first_iter_original=FLAGS.first_iter_original,
@@ -401,6 +412,5 @@ def main(argv):
401412
############################################################################
402413
trainer.train(data)
403414

404-
405415
if __name__ == '__main__':
406416
app.run(main)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Utilities for virtual adversarial training."""
15+
import tensorflow as tf
16+
17+
# Factor that multiplies the normalized perturbation before it is returned by
# generate_virtual_adversarial_perturbation.
epsilon = 5
18+
# Number of passes of the gradient-refinement loop in
# generate_virtual_adversarial_perturbation.
num_power_iterations = 1
19+
# Factor that multiplies the normalized direction before it is added to the
# inputs inside the refinement loop.
xi = 1e-6
20+
# When True, the perturbation is additionally multiplied by
# get_normalizing_constant(inputs) before the epsilon scaling.
scale_r = False
21+
22+
23+
def kl_divergence_with_logit(q_logit, p_logit):
  """Computes the KL-divergence between two sets of logits.

  Both arguments are unnormalized logits. The distribution q is obtained by
  applying a softmax to `q_logit`, and it is compared against the distribution
  induced by `p_logit`.

  Args:
    q_logit: Logits defining the distribution q.
    p_logit: Logits defining the distribution p.

  Returns:
    A Tensor holding the cross entropy of q with respect to `p_logit` minus
    the cross entropy of q with respect to `q_logit`.
  """
  q_probs = tf.nn.softmax(q_logit)
  # Cross entropy of q against its own logits, i.e. the entropy of q.
  entropy_q = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q_probs, logits=q_logit)
  # Cross entropy of q against the logits of p.
  cross_entropy_qp = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q_probs, logits=p_logit)
  # Same value as (-entropy_q) - (-cross_entropy_qp).
  return cross_entropy_qp - entropy_q
29+
30+
31+
def get_normalized_vector(d):
  """Normalizes the provided input tensor.

  The tensor is first divided by its largest absolute entry and then by the
  square root of the sum of its squared entries. Small constants are added to
  both divisors so that an all-zero input does not cause a division by zero.

  NOTE(review): neither reduction specifies an axis, so both run over every
  element of `d`, including the batch dimension. Confirm that per-example
  normalization is not what is intended here.

  Args:
    d: A float Tensor to normalize.

  Returns:
    The rescaled Tensor, with the same shape as `d`.
  """
  # `keepdims` is the supported spelling; `keep_dims` is deprecated.
  d /= (1e-12 + tf.reduce_max(tf.abs(d), keepdims=True))
  d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keepdims=True))
  return d
36+
37+
38+
def get_normalizing_constant(d):
  """Returns the normalizing constant to scale the VAT perturbation vector.

  The constant is the product of (1e-12 + the largest absolute entry of `d`)
  and the square root of (1e-6 + the sum of the squared entries of `d`). Both
  reductions run over all elements of `d` because no axis is given.

  Args:
    d: A float Tensor.

  Returns:
    A Tensor with the same rank as `d` and every dimension of size 1.
  """
  # `keepdims` is the supported spelling; `keep_dims` is deprecated.
  c = 1e-12 + tf.reduce_max(tf.abs(d), keepdims=True)
  c *= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keepdims=True))
  return c
43+
44+
45+
def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
  """Computes the virtual adversarial loss for the provided inputs.

  A perturbation is generated for the inputs, the model is re-applied to the
  perturbed inputs under the same variable scope (with variable reuse), and
  the loss is the mean KL-divergence between the original predictions and the
  predictions on the perturbed inputs.

  Args:
    inputs: A batch of input features, where the batch is the first dimension.
    predictions: The logits predicted by a model on the provided inputs.
    is_train: A boolean placeholder specifying if this is a training or testing
      setting.
    model: The model that generated the logits. It must provide
      `get_encoding_and_params` and `get_predictions_and_params`.
    predictions_var_scope: Variable scope for obtaining the predictions.

  Returns:
    A scalar Tensor representing the virtual adversarial loss.
  """
  # The perturbation is built from the predictions before gradients are
  # stopped on them below.
  perturbation = generate_virtual_adversarial_perturbation(
      inputs, predictions, model, predictions_var_scope, is_train=is_train)

  # The unperturbed predictions act as a fixed target for the KL term.
  clean_logits = tf.stop_gradient(predictions)
  perturbed_inputs = tf.add(inputs, perturbation)

  # Re-run the model on the perturbed inputs, reusing its existing variables.
  with tf.variable_scope(
      predictions_var_scope, auxiliary_name_scope=False, reuse=True):
    perturbed_encoding, _, _ = model.get_encoding_and_params(
        inputs=perturbed_inputs, is_train=is_train, update_batch_stats=False)
    perturbed_logits, _, _ = model.get_predictions_and_params(
        encoding=perturbed_encoding, is_train=is_train)

  divergence = kl_divergence_with_logit(clean_logits, perturbed_logits)
  return tf.reduce_mean(divergence)
72+
73+
74+
def generate_virtual_adversarial_perturbation(inputs,
                                              logits,
                                              model,
                                              predictions_var_scope,
                                              is_train=True):
  """Generates an adversarial perturbation for virtual adversarial training.

  Starting from random normal noise, the direction is refined
  `num_power_iterations` times: it is normalized and scaled by `xi`, added to
  the inputs, and replaced by the gradient of the KL-divergence between the
  provided logits and the logits of the perturbed inputs. The final direction
  is normalized, optionally scaled by a constant derived from the inputs (when
  `scale_r` is set), and multiplied by `epsilon`.

  Args:
    inputs: A batch of input features, where the batch is the first dimension.
    logits: The logits predicted by a model on the provided inputs.
    model: The model that generated the logits. It must provide
      `get_encoding_and_params` and `get_predictions_and_params`.
    predictions_var_scope: Variable scope for obtaining the predictions.
    is_train: A boolean placeholder specifying if this is a training or testing
      setting.

  Returns:
    A Tensor of the same shape as the inputs containing the adversarial
    perturbation for these inputs.
  """
  direction = tf.random_normal(shape=tf.shape(inputs))

  for _ in range(num_power_iterations):
    # Small probe step along the current direction.
    probe = xi * get_normalized_vector(direction)
    with tf.variable_scope(
        predictions_var_scope, auxiliary_name_scope=False, reuse=True):
      probe_encoding, _, _ = model.get_encoding_and_params(
          inputs=probe + inputs, is_train=is_train, update_batch_stats=False)
      probe_logits, _, _ = model.get_predictions_and_params(
          encoding=probe_encoding, is_train=is_train)
    divergence = kl_divergence_with_logit(logits, probe_logits)
    # NOTE(review): 2 presumably corresponds to
    # tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N -- confirm.
    gradient = tf.gradients(divergence, [probe], aggregation_method=2)[0]
    # The gradient becomes the next direction; nothing backpropagates through
    # it.
    direction = tf.stop_gradient(gradient)

  perturbation = get_normalized_vector(direction)
  if scale_r:
    perturbation = perturbation * get_normalizing_constant(inputs)
  return perturbation * epsilon
113+
114+
115+
def entropy_y_x(logits):
  """Entropy term to add to VAT with entropy minimization.

  The softmax of the logits is used as the label distribution in a softmax
  cross entropy against the same logits, and the result is averaged.

  Args:
    logits: A Tensor containing the predicted logits for a batch of samples.

  Returns:
    A scalar Tensor with the entropy minimization loss.
  """
  probabilities = tf.nn.softmax(logits)
  per_sample_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=probabilities, logits=logits)
  return tf.reduce_mean(per_sample_entropy)

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,9 @@ def _eval_train(self, session, feed_dict):
442442
feed_dict: A train feed dictionary.
443443
444444
Returns:
445-
The computed train accuracy.
445+
train_acc: The computed train accuracy.
446+
acc_0: Accuracy for class 0.
447+
acc_1: Accuracy for class 1.
446448
"""
447449
train_acc, pred, targ = session.run(
448450
(self.accuracy, self.normalized_predictions, self.labels),
@@ -462,9 +464,7 @@ def _eval_train(self, session, feed_dict):
462464
acc_0 = sum(acc_0) / np.float32(len(acc_0))
463465
else:
464466
acc_0 = -1
465-
logging.info('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f',
466-
train_acc, acc_1, acc_0)
467-
return train_acc
467+
return train_acc, acc_0, acc_1
468468

469469
def _eval_validation(self, data_iterator_val, num_samples_val, session):
470470
"""Evaluate the current model on validation data.
@@ -685,7 +685,8 @@ def train(self, data, session=None, **kwargs):
685685
# Evaluate the accuracy on the latest train batch. We track this to make
686686
# sure the agreement model is able to fit the training data, but can be
687687
# eliminated if efficiency is an issue.
688-
acc_train = self._eval_train(session, feed_dict)
688+
acc_train, acc_0_train, acc_1_train = self._eval_train(
689+
session, feed_dict)
689690

690691
if self.enable_summaries:
691692
summary = tf.Summary()
@@ -700,9 +701,10 @@ def train(self, data, session=None, **kwargs):
700701
summary_writer.flush()
701702
if step % self.logging_step == 0 or val_acc > best_val_acc:
702703
logging.info(
703-
'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f |'
704-
'random_acc: %10.4f | acc_train: %10.4f', step, loss_val, val_acc,
705-
acc_random, acc_train)
704+
'Agreement step %6d | Loss: %10.4f | val_acc: %.4f |'
705+
'random_acc: %.4f | acc_train: %.4f | acc_train_cls_0: %.4f | '
706+
'acc_train_cls_1: %.4f', step, loss_val, val_acc, acc_random,
707+
acc_train, acc_0_train, acc_1_train)
706708
if val_acc > best_val_acc:
707709
best_val_acc = val_acc
708710
if self.checkpoint_path:

0 commit comments

Comments
 (0)