
Commit 25d945a

Update documentation and indentation.

1 parent 006e48e


neural_structured_learning/research/gam/trainer/adversarial.py

Lines changed: 96 additions & 94 deletions
@@ -7,110 +7,112 @@
 
 
 def kl_divergence_with_logit(q_logit, p_logit):
-    """Computes KL-divergence between to sets of logits."""
-    q = tf.nn.softmax(q_logit)
-    qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
-        labels=q, logits=q_logit)
-    qlogp = -tf.nn.softmax_cross_entropy_with_logits_v2(
-        labels=q, logits=p_logit)
-    return qlogq - qlogp
+  """Computes KL-divergence between two sets of logits."""
+  q = tf.nn.softmax(q_logit)
+  qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
+      labels=q, logits=q_logit)
+  qlogp = -tf.nn.softmax_cross_entropy_with_logits_v2(
+      labels=q, logits=p_logit)
+  return qlogq - qlogp
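This helper relies on the identity that -tf.nn.softmax_cross_entropy_with_logits_v2(labels=q, logits=l) equals sum_i q_i * log softmax(l)_i, so qlogq - qlogp reduces to sum_i q_i * (log q_i - log p_i) = KL(q || p). A minimal numeric check, assuming the TF 1.x APIs used in this file:

import numpy as np
import tensorflow as tf

q_logit = tf.constant([[1.0, 2.0, 0.5]])
p_logit = tf.constant([[0.5, 1.5, 1.0]])
kl = kl_divergence_with_logit(q_logit, p_logit)

with tf.Session() as sess:
  q, p, kl_val = sess.run([tf.nn.softmax(q_logit), tf.nn.softmax(p_logit), kl])
  # Direct computation of KL(q || p) for comparison; should match kl_val.
  print(kl_val, np.sum(q * (np.log(q) - np.log(p)), axis=1))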
 
 
 def get_normalized_vector(d):
-    d /= (1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True))
-    d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
-    return d
+  """Normalizes the provided input vector."""
+  d /= (1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True))
+  d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
+  return d
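Worth noting: with keep_dims=True and no axis argument, both reductions in get_normalized_vector run over the entire tensor, so the output has roughly unit L2 norm as a whole rather than per example; the max-abs division exists only for numerical stability. A short check under the same TF 1.x assumptions:

import tensorflow as tf

d = tf.random_normal(shape=[4, 3])
d_hat = get_normalized_vector(d)
# Global (not per-row) L2 norm; close to 1 up to the 1e-12 and 1e-6 terms.
global_norm = tf.sqrt(tf.reduce_sum(tf.pow(d_hat, 2.0)))

with tf.Session() as sess:
  print(sess.run(global_norm))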
 
 
 def get_normalizing_constant(d):
-    c = 1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True)
-    c *= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
-    return c
+  """Returns the normalizing constant to scale the VAT perturbation vector."""
+  c = 1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True)
+  c *= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
+  return c
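This constant is exactly the product of the two divisors applied in get_normalized_vector, so it approximately undoes that normalization; the VAT code below applies it to the inputs so that, when scale_r is set, the perturbation is scaled to the inputs' own magnitude. A sketch of the round-trip property, again assuming TF 1.x:

import numpy as np
import tensorflow as tf

d = tf.random_normal(shape=[4, 3])
restored = get_normalized_vector(d) * get_normalizing_constant(d)

with tf.Session() as sess:
  orig, back = sess.run([d, restored])  # One run, so both share one sample.
  print(np.abs(orig - back).max())      # Small, up to the fudge terms.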
 
 
 def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
-    """Computes the virtual adversarial loss for the provided inputs.
-
-    Args:
-      inputs: A batch of input features, where the batch is the first
-        dimension.
-      predictions: The logits predicted by a model on the provided inputs.
-      is_train: A boolean placeholder specifying if this is a training or
-        testing setting.
-      model: The model that generated the logits.
-      predictions_var_scope: Variable scope for obtaining the predictions.
-    Returns:
-      A float value representing the virtual adversarial loss.
-    """
-    r_vadv = generate_virtual_adversarial_perturbation(
-        inputs, predictions, model, predictions_var_scope, is_train=is_train)
-    predictions = tf.stop_gradient(predictions)
-    logit_p = predictions
-    new_inputs = tf.add(inputs, r_vadv)
-    with tf.variable_scope(
-        predictions_var_scope, auxiliary_name_scope=False, reuse=True):
-        encoding_m, _, _ = model.get_encoding_and_params(
-            inputs=new_inputs,
-            is_train=is_train,
-            update_batch_stats=False)
-        logit_m, _, _ = model.get_predictions_and_params(
-            encoding=encoding_m,
-            is_train=is_train)
-    loss = kl_divergence_with_logit(logit_p, logit_m)
-    return tf.reduce_mean(loss)
+  """Computes the virtual adversarial loss for the provided inputs.
+
+  Args:
+    inputs: A batch of input features, where the batch is the first
+      dimension.
+    predictions: The logits predicted by a model on the provided inputs.
+    is_train: A boolean placeholder specifying if this is a training or
+      testing setting.
+    model: The model that generated the logits.
+    predictions_var_scope: Variable scope for obtaining the predictions.
+  Returns:
+    A float value representing the virtual adversarial loss.
+  """
+  r_vadv = generate_virtual_adversarial_perturbation(
+      inputs, predictions, model, predictions_var_scope, is_train=is_train)
+  predictions = tf.stop_gradient(predictions)
+  logit_p = predictions
+  new_inputs = tf.add(inputs, r_vadv)
+  with tf.variable_scope(
+      predictions_var_scope, auxiliary_name_scope=False, reuse=True):
+    encoding_m, _, _ = model.get_encoding_and_params(
+        inputs=new_inputs,
+        is_train=is_train,
+        update_batch_stats=False)
+    logit_m, _, _ = model.get_predictions_and_params(
+        encoding=encoding_m,
+        is_train=is_train)
+  loss = kl_divergence_with_logit(logit_p, logit_m)
+  return tf.reduce_mean(loss)
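A sketch of how get_loss_vat is typically wired into a semi-supervised objective. The names features, labels and vat_weight and the 'predictions' scope name are illustrative placeholders, not taken from this file; the one real requirement is that the model variables already exist under predictions_var_scope so that reuse=True succeeds:

with tf.variable_scope('predictions') as predictions_var_scope:
  encoding, _, _ = model.get_encoding_and_params(
      inputs=features, is_train=is_train, update_batch_stats=True)
  predictions, _, _ = model.get_predictions_and_params(
      encoding=encoding, is_train=is_train)

supervised_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels, logits=predictions))
# The VAT term is a label-free smoothness penalty, so it can also be
# computed on unlabeled batches.
total_loss = supervised_loss + vat_weight * get_loss_vat(
    features, predictions, is_train, model, predictions_var_scope)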
 
 
 def generate_virtual_adversarial_perturbation(
     inputs, logits, model, predictions_var_scope, is_train=True):
-    """Generates an adversarial perturbation for virtual adversarial training.
-
-    Args:
-      inputs: A batch of input features, where the batch is the first
-        dimension.
-      logits: The logits predicted by a model on the provided inputs.
-      model: The model that generated the logits.
-      predictions_var_scope: Variable scope for obtaining the predictions.
-      is_train: A boolean placeholder specifying if this is a training or
-        testing setting.
-
-    Returns:
-      A Tensor of the same shape as the inputs containing the adversarial
-      perturbation for these inputs.
-    """
-    d = tf.random_normal(shape=tf.shape(inputs))
-
-    for _ in range(num_power_iterations):
-        d = xi * get_normalized_vector(d)
-        logit_p = logits
-        with tf.variable_scope(
-            predictions_var_scope, auxiliary_name_scope=False, reuse=True):
-            encoding_m, _, _ = model.get_encoding_and_params(
-                inputs=d + inputs,
-                is_train=is_train,
-                update_batch_stats=False)
-            logit_m, _, _ = model.get_predictions_and_params(
-                encoding=encoding_m,
-                is_train=is_train)
-        dist = kl_divergence_with_logit(logit_p, logit_m)
-        grad = tf.gradients(dist, [d], aggregation_method=2)[0]
-        d = tf.stop_gradient(grad)
-
-    r_vadv = get_normalized_vector(d)
-    if scale_r:
-        r_vadv *= get_normalizing_constant(inputs)
-    r_vadv *= epsilon
-    return r_vadv
-
-
-def logsoftmax(x):
-    """Implementation of softmax when the inputs are logits."""
-    xdev = x - tf.reduce_max(x, 1, keep_dims=True)
-    lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keep_dims=True))
-    return lsm
-
-
-def entropy_y_x(logit):
-    """Entropy term to add to VATENT."""
-    p = tf.nn.softmax(logit)
-    return tf.reduce_mean(
-        tf.nn.softmax_cross_entropy_with_logits_v2(labels=p, logits=logit))
+  """Generates an adversarial perturbation for virtual adversarial training.
+
+  Args:
+    inputs: A batch of input features, where the batch is the first
+      dimension.
+    logits: The logits predicted by a model on the provided inputs.
+    model: The model that generated the logits.
+    predictions_var_scope: Variable scope for obtaining the predictions.
+    is_train: A boolean placeholder specifying if this is a training or
+      testing setting.
+
+  Returns:
+    A Tensor of the same shape as the inputs containing the adversarial
+    perturbation for these inputs.
+  """
+  d = tf.random_normal(shape=tf.shape(inputs))
+
+  for _ in range(num_power_iterations):
+    d = xi * get_normalized_vector(d)
+    logit_p = logits
+    with tf.variable_scope(
+        predictions_var_scope, auxiliary_name_scope=False, reuse=True):
+      encoding_m, _, _ = model.get_encoding_and_params(
+          inputs=d + inputs,
+          is_train=is_train,
+          update_batch_stats=False)
+      logit_m, _, _ = model.get_predictions_and_params(
+          encoding=encoding_m,
+          is_train=is_train)
+    dist = kl_divergence_with_logit(logit_p, logit_m)
+    grad = tf.gradients(dist, [d], aggregation_method=2)[0]
+    d = tf.stop_gradient(grad)
+
+  r_vadv = get_normalized_vector(d)
+  if scale_r:
+    r_vadv *= get_normalizing_constant(inputs)
+  r_vadv *= epsilon
+  return r_vadv
+
+
+def entropy_y_x(logits):
+  """Entropy term to add to VAT with entropy minimization.
+
+  Args:
+    logits: A Tensor containing the predicted logits for a batch of samples.
+
+  Returns:
+    The entropy minimization loss.
+  """
+  p = tf.nn.softmax(logits)
+  return tf.reduce_mean(
+      tf.nn.softmax_cross_entropy_with_logits_v2(labels=p, logits=logits))
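entropy_y_x returns the mean Shannon entropy of the predictions: the cross-entropy of p against the logits that produced p is -sum_i p_i log p_i = H(p). A quick check, again assuming TF 1.x:

import numpy as np
import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 0.5]])
ent = entropy_y_x(logits)

with tf.Session() as sess:
  p = sess.run(tf.nn.softmax(logits))
  print(sess.run(ent), -np.sum(p * np.log(p), axis=1))  # Should match.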

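Two loose ends for anyone reusing this hunk in isolation. First, generate_virtual_adversarial_perturbation reads num_power_iterations, xi, epsilon and scale_r from its enclosing module; the values below are illustrative defaults in the spirit of Miyato et al.'s VAT, not values taken from this commit. Second, the "VATENT" objective that the old entropy_y_x docstring mentioned is simply the VAT loss plus this entropy term:

# Illustrative module-level hyperparameters (not from this commit).
num_power_iterations = 1  # One power-iteration step is the common choice.
xi = 1e-6                 # Scale of the finite-difference probe direction.
epsilon = 5.0             # Norm of the final adversarial perturbation.
scale_r = False           # Optionally rescale by the inputs' magnitude.

# VATENT: supervised loss + VAT smoothness + entropy minimization, reusing
# supervised_loss and the hypothetical weights from the sketch above.
vatent_loss = (supervised_loss
               + vat_weight * get_loss_vat(features, predictions, is_train,
                                           model, predictions_var_scope)
               + ent_weight * entropy_y_x(predictions))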