@@ -491,6 +491,7 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
491
491
# edges at the end of training, so the shapes don't match needs fixing.
492
492
left = tf .concat ((labels_ll_left , labels_lu_left , predictions_uu_left ),
493
493
axis = 0 )
494
+ # left = tf.stop_gradient(left)
494
495
right = tf .concat (
495
496
(predictions_ll_right , predictions_lu_right , predictions_uu_right ),
496
497
axis = 0 )
@@ -514,13 +515,14 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
514
515
src_indices = indices_uu_left ,
515
516
tgt_indices = indices_uu_right )
516
517
agreement = tf .concat ((agreement_ll , agreement_lu , agreement_uu ), axis = 0 )
518
+ # agreement = tf.stop_gradient(agreement)
517
519
if self .penalize_neg_agr :
518
520
# Since the agreement is predicting scores between [0, 1], anything
519
521
# under 0.5 should represent disagreement. Therefore, we want to encourage
520
522
# agreement whenever the score is > 0.5, otherwise don't incur any loss.
521
523
agreement = tf .nn .relu (agreement - 0.5 )
522
524
523
- # Create a Tensor containing the weights assigned to each pair in the
525
+ # Create a Tensor containing the weights assigned to each pair in the
524
526
# agreement regularization loss, depending on how many samples in the pair
525
527
# were labeled. This weight can be either reg_weight_ll, reg_weight_lu,
526
528
# or reg_weight_uu.
0 commit comments