Skip to content

Commit 23c7d17

Browse files
committed
Fix issue with regularization parameters not being passed.
1 parent 5a99f0d commit 23c7d17

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,18 @@ def __init__(self,
209209

210210
# Create variables and predictions.
211211
with tf.variable_scope('predictions'):
212-
encoding, variables, reg_params = self.model.get_encoding_and_params(
213-
inputs=input_features, is_train=is_train)
214-
self.variables = variables
215-
self.reg_params = reg_params
216-
predictions, variables, reg_params = (
212+
encoding, variables_enc, reg_params_enc = (
213+
self.model.get_encoding_and_params(
214+
inputs=input_features,
215+
is_train=is_train))
216+
self.variables = variables_enc
217+
self.reg_params = reg_params_enc
218+
predictions, variables_pred, reg_params_pred = (
217219
self.model.get_predictions_and_params(
218-
encoding=encoding, is_train=is_train))
219-
self.variables.update(variables)
220-
self.reg_params.update(reg_params)
220+
encoding=encoding,
221+
is_train=is_train))
222+
self.variables.update(variables_pred)
223+
self.reg_params.update(reg_params_pred)
221224
normalized_predictions = self.model.normalize_predictions(predictions)
222225
predictions_var_scope = tf.get_variable_scope()
223226

@@ -262,7 +265,7 @@ def __init__(self,
262265
# Weight decay loss.
263266
loss_reg = 0.0
264267
if weight_decay_var is not None:
265-
for var in reg_params.values():
268+
for var in self.reg_params.values():
266269
loss_reg += weight_decay_var * tf.nn.l2_loss(var)
267270

268271
# Adversarial loss, in case we want to add VAT on top of GAM.
@@ -351,7 +354,7 @@ def __init__(self,
351354
if isinstance(weight_decay_var, tf.Variable):
352355
self.vars_to_save.append(weight_decay_var)
353356
if self.warm_start:
354-
self.vars_to_save.extend([v for v in variables])
357+
self.vars_to_save.extend([v for v in self.variables])
355358

356359
# More variables to be initialized after the session is created.
357360
self.is_initialized = False
@@ -366,7 +369,6 @@ def __init__(self,
366369
self.weight_decay_update = weight_decay_update
367370
self.iter_cls_total = iter_cls_total
368371
self.iter_cls_total_update = iter_cls_total_update
369-
self.variables = variables
370372
self.accuracy = accuracy
371373
self.train_op = train_op
372374
self.loss_op = loss_op

0 commit comments

Comments
 (0)