Skip to content

Commit 5a99f0d

Browse files
committed
Added stop_gradients as a comment for now.
1 parent 6544d76 commit 5a99f0d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
491491
# edges at the end of training, so the shapes don't match needs fixing.
492492
left = tf.concat((labels_ll_left, labels_lu_left, predictions_uu_left),
493493
axis=0)
494+
# left = tf.stop_gradient(left)
494495
right = tf.concat(
495496
(predictions_ll_right, predictions_lu_right, predictions_uu_right),
496497
axis=0)
@@ -514,13 +515,14 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
514515
src_indices=indices_uu_left,
515516
tgt_indices=indices_uu_right)
516517
agreement = tf.concat((agreement_ll, agreement_lu, agreement_uu), axis=0)
518+
# agreement = tf.stop_gradient(agreement)
517519
if self.penalize_neg_agr:
518520
# Since the agreement is predicting scores between [0, 1], anything
519521
# under 0.5 should represent disagreement. Therefore, we want to encourage
520522
# agreement whenever the score is > 0.5, otherwise don't incur any loss.
521523
agreement = tf.nn.relu(agreement - 0.5)
522524

523-
# Create a Tensor containing the weights assigned to each pair in the
525+
# Create a Tensor containing the weights assigned to each pair in the
524526
# agreement regularization loss, depending on how many samples in the pair
525527
# were labeled. This weight can be either reg_weight_ll, reg_weight_lu,
526528
# or reg_weight_uu.

0 commit comments

Comments
 (0)