@@ -209,15 +209,18 @@ def __init__(self,
 
     # Create variables and predictions.
     with tf.variable_scope('predictions'):
-      encoding, variables, reg_params = self.model.get_encoding_and_params(
-          inputs=input_features, is_train=is_train)
-      self.variables = variables
-      self.reg_params = reg_params
-      predictions, variables, reg_params = (
+      encoding, variables_enc, reg_params_enc = (
+          self.model.get_encoding_and_params(
+              inputs=input_features,
+              is_train=is_train))
+      self.variables = variables_enc
+      self.reg_params = reg_params_enc
+      predictions, variables_pred, reg_params_pred = (
           self.model.get_predictions_and_params(
-              encoding=encoding, is_train=is_train))
-      self.variables.update(variables)
-      self.reg_params.update(reg_params)
+              encoding=encoding,
+              is_train=is_train))
+      self.variables.update(variables_pred)
+      self.reg_params.update(reg_params_pred)
       normalized_predictions = self.model.normalize_predictions(predictions)
       predictions_var_scope = tf.get_variable_scope()
 
@@ -262,7 +265,7 @@ def __init__(self,
     # Weight decay loss.
     loss_reg = 0.0
     if weight_decay_var is not None:
-      for var in reg_params.values():
+      for var in self.reg_params.values():
         loss_reg += weight_decay_var * tf.nn.l2_loss(var)
 
     # Adversarial loss, in case we want to add VAT on top of GAM.
@@ -351,7 +354,7 @@ def __init__(self,
     if isinstance(weight_decay_var, tf.Variable):
       self.vars_to_save.append(weight_decay_var)
     if self.warm_start:
-      self.vars_to_save.extend([v for v in variables])
+      self.vars_to_save.extend([v for v in self.variables])
 
     # More variables to be initialized after the session is created.
     self.is_initialized = False
@@ -366,7 +369,6 @@ def __init__(self,
    self.weight_decay_update = weight_decay_update
    self.iter_cls_total = iter_cls_total
    self.iter_cls_total_update = iter_cls_total_update
-    self.variables = variables
    self.accuracy = accuracy
    self.train_op = train_op
    self.loss_op = loss_op
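Context for the change: in the original code, the locals `variables` and `reg_params` returned by `get_encoding_and_params` were rebound by the later `get_predictions_and_params` call, so downstream uses (the weight decay loop, `vars_to_save`, and the final `self.variables = variables` assignment) only saw the predictor's parameters. After this change, those sites read the merged `self.variables` / `self.reg_params` dicts instead. Below is a minimal sketch of the resulting pattern, assuming the TF1-style API shown in the diff; the function name, its arguments, and the local dicts are illustrative, not part of the commit.

```python
import tensorflow as tf  # TF1-style API, as used in the diff


def build_predictions_and_decay(model, input_features, is_train,
                                weight_decay_var):
  """Illustrative sketch: merge encoder and predictor params, then regularize."""
  with tf.variable_scope('predictions'):
    # Keep the encoder's parameters under their own names...
    encoding, variables_enc, reg_params_enc = model.get_encoding_and_params(
        inputs=input_features, is_train=is_train)
    variables = dict(variables_enc)
    reg_params = dict(reg_params_enc)
    # ...and merge in the predictor's parameters instead of rebinding the
    # same local names, which previously hid the encoder's params from the
    # weight decay loop and from vars_to_save.
    predictions, variables_pred, reg_params_pred = (
        model.get_predictions_and_params(encoding=encoding, is_train=is_train))
    variables.update(variables_pred)
    reg_params.update(reg_params_pred)

  # Weight decay now runs over the merged dict, covering both sub-networks.
  loss_reg = 0.0
  if weight_decay_var is not None:
    for var in reg_params.values():
      loss_reg += weight_decay_var * tf.nn.l2_loss(var)
  return predictions, variables, reg_params, loss_reg
```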