def kl_divergence_with_logit(q_logit, p_logit):
  """Computes KL-divergence between two sets of logits."""
  q = tf.nn.softmax(q_logit)
  qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q, logits=q_logit)
  qlogp = -tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q, logits=p_logit)
  return qlogq - qlogp
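

# A minimal sanity check for kl_divergence_with_logit (a sketch assuming the
# TF 1.x graph-mode APIs used throughout this file). The result should match
# the explicit sum-over-classes KL formula; `_check_kl_divergence` is a
# hypothetical helper, not part of the original file.
def _check_kl_divergence():
  q_logit = tf.constant([[2.0, 0.5, -1.0]])
  p_logit = tf.constant([[0.1, 0.2, 0.3]])
  kl = kl_divergence_with_logit(q_logit, p_logit)
  # KL(q || p) = sum_i q_i * (log q_i - log p_i), computed explicitly.
  q = tf.nn.softmax(q_logit)
  p = tf.nn.softmax(p_logit)
  kl_manual = tf.reduce_sum(q * (tf.log(q) - tf.log(p)), axis=1)
  with tf.Session() as sess:
    print(sess.run([kl, kl_manual]))  # The two values should agree.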


def get_normalized_vector(d):
  """Normalizes the provided input vector."""
  # Divide by the maximum absolute value first for numerical stability, then
  # by the L2 norm; the small constants guard against division by zero.
  d /= (1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True))
  d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
  return d


def get_normalizing_constant(d):
  """Returns the normalizing constant to scale the VAT perturbation vector."""
  c = 1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True)
  c *= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
  return c


def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
  """Computes the virtual adversarial loss for the provided inputs.

  Args:
    inputs: A batch of input features, where the batch is the first
      dimension.
    predictions: The logits predicted by a model on the provided inputs.
    is_train: A boolean placeholder specifying if this is a training or
      testing setting.
    model: The model that generated the logits.
    predictions_var_scope: Variable scope for obtaining the predictions.

  Returns:
    A float value representing the virtual adversarial loss.
  """
  r_vadv = generate_virtual_adversarial_perturbation(
      inputs, predictions, model, predictions_var_scope, is_train=is_train)
  # Stop gradients so the VAT loss is not backpropagated through the
  # predictions on the clean inputs.
  predictions = tf.stop_gradient(predictions)
  logit_p = predictions
  new_inputs = tf.add(inputs, r_vadv)
  # Reuse the model variables to predict on the perturbed inputs.
  with tf.variable_scope(
      predictions_var_scope, auxiliary_name_scope=False, reuse=True):
    encoding_m, _, _ = model.get_encoding_and_params(
        inputs=new_inputs,
        is_train=is_train,
        update_batch_stats=False)
    logit_m, _, _ = model.get_predictions_and_params(
        encoding=encoding_m,
        is_train=is_train)
  loss = kl_divergence_with_logit(logit_p, logit_m)
  return tf.reduce_mean(loss)


def generate_virtual_adversarial_perturbation(
    inputs, logits, model, predictions_var_scope, is_train=True):
  """Generates an adversarial perturbation for virtual adversarial training.

  Args:
    inputs: A batch of input features, where the batch is the first
      dimension.
    logits: The logits predicted by a model on the provided inputs.
    model: The model that generated the logits.
    predictions_var_scope: Variable scope for obtaining the predictions.
    is_train: A boolean placeholder specifying if this is a training or
      testing setting.

  Returns:
    A Tensor of the same shape as the inputs containing the adversarial
    perturbation for these inputs.
  """
  # Start from a random direction and refine it with power iterations into
  # the direction that most increases the KL-divergence of the predictions.
  d = tf.random_normal(shape=tf.shape(inputs))

  for _ in range(num_power_iterations):
    d = xi * get_normalized_vector(d)
    logit_p = logits
    with tf.variable_scope(
        predictions_var_scope, auxiliary_name_scope=False, reuse=True):
      encoding_m, _, _ = model.get_encoding_and_params(
          inputs=d + inputs,
          is_train=is_train,
          update_batch_stats=False)
      logit_m, _, _ = model.get_predictions_and_params(
          encoding=encoding_m,
          is_train=is_train)
    dist = kl_divergence_with_logit(logit_p, logit_m)
    grad = tf.gradients(dist, [d], aggregation_method=2)[0]
    d = tf.stop_gradient(grad)

  r_vadv = get_normalized_vector(d)
  if scale_r:
    r_vadv *= get_normalizing_constant(inputs)
  r_vadv *= epsilon
  return r_vadv
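

# The function above relies on module-level hyperparameters defined elsewhere
# in this file. Illustrative values only (assumptions, not the settings used
# in any experiment):
#
#   num_power_iterations = 1  # Power iterations approximating the most
#                             # sensitive perturbation direction.
#   xi = 1e-6                 # Scale of the finite-difference probe step.
#   epsilon = 5.0             # Norm of the final perturbation r_vadv.
#   scale_r = False           # Whether to rescale r_vadv by the input norm.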


def entropy_y_x(logits):
  """Entropy term to add to VAT with entropy minimization.

  Args:
    logits: A Tensor containing the predicted logits for a batch of samples.

  Returns:
    The entropy minimization loss.
  """
  p = tf.nn.softmax(logits)
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(labels=p, logits=logits))
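

# A minimal sketch of how the losses above might be combined into a VATENT
# objective (VAT plus entropy minimization). `get_total_loss_vatent`,
# `vat_weight`, and `entropy_weight` are hypothetical names, not part of
# this file.
def get_total_loss_vatent(inputs, labels, predictions, is_train, model,
                          predictions_var_scope, vat_weight, entropy_weight):
  """Combines the supervised, VAT, and entropy losses. Illustrative only."""
  supervised_loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=labels, logits=predictions))
  vat_loss = get_loss_vat(inputs, predictions, is_train, model,
                          predictions_var_scope)
  entropy_loss = entropy_y_x(predictions)
  return (supervised_loss + vat_weight * vat_loss +
          entropy_weight * entropy_loss)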