Skip to content

Commit 006e48e

Browse files
committed
Merge remote-tracking branch 'upstream/master' into vat
# Conflicts: # neural_structured_learning/research/gam/data/dataset.py # neural_structured_learning/research/gam/data/loaders.py # neural_structured_learning/research/gam/experiments/helper.py # neural_structured_learning/research/gam/experiments/run_train_gam.py # neural_structured_learning/research/gam/experiments/run_train_gam_graph.py # neural_structured_learning/research/gam/trainer/trainer_agreement.py # neural_structured_learning/research/gam/trainer/trainer_classification.py # neural_structured_learning/research/gam/trainer/trainer_cotrain.py
1 parent 1798069 commit 006e48e

File tree

2 files changed

+8
-34
lines changed

2 files changed

+8
-34
lines changed

neural_structured_learning/research/gam/data/preprocessing.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,30 +28,6 @@ def convert_image(image):
2828
image *= 2.
2929
return image
3030

31-
def split_train_val(indices, ratio_val, rng, max_num_val=None):
32-
"""Split the train sample indices into train and validation.
33-
34-
Args:
35-
indices: A numpy array containing the indices of the training samples.
36-
ratio_val: A float number between (0, 1) representing the ratio of samples
37-
to use for validation.
38-
rng: A random number generator.
39-
max_num_val: An integer representing the maximum number of samples to
40-
include in the validation set.
41-
42-
Returns:
43-
Two numpy arrays containing the subset of indices used for training, and
44-
validation, respectively.
45-
"""
46-
num_samples = indices.shape[0]
47-
num_val = int(ratio_val * num_samples)
48-
if max_num_val and num_val > max_num_val:
49-
num_val = max_num_val
50-
ind = np.arange(0, num_samples)
51-
rng.shuffle(ind)
52-
ind_val = ind[:num_val]
53-
ind_train = ind[num_val:]
54-
return ind_train, ind_val
5531

5632
def split_train_val(indices, ratio_val, rng, max_num_val=None):
5733
"""Split the train sample indices into train and validation.

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,6 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
493493
# edges at the end of training, so the shapes don't match needs fixing.
494494
left = tf.concat((labels_ll_left, labels_lu_left, predictions_uu_left),
495495
axis=0)
496-
# left = tf.stop_gradient(left)
497496
right = tf.concat(
498497
(predictions_ll_right, predictions_lu_right, predictions_uu_right),
499498
axis=0)
@@ -517,7 +516,6 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
517516
src_indices=indices_uu_left,
518517
tgt_indices=indices_uu_right)
519518
agreement = tf.concat((agreement_ll, agreement_lu, agreement_uu), axis=0)
520-
# agreement = tf.stop_gradient(agreement)
521519
if self.penalize_neg_agr:
522520
# Since the agreement is predicting scores between [0, 1], anything
523521
# under 0.5 should represent disagreement. Therefore, we want to encourage
@@ -712,17 +710,17 @@ def edge_iterator(self, data, batch_size, labeling):
712710
def _evaluate(self, indices, split, session, summary_writer):
713711
"""Evaluates the samples with the provided indices."""
714712
data_iterator_val = batch_iterator(
715-
indices,
716-
batch_size=self.batch_size,
717-
shuffle=False,
718-
allow_smaller_batch=True,
719-
repeat=False)
713+
indices,
714+
batch_size=self.batch_size,
715+
shuffle=False,
716+
allow_smaller_batch=True,
717+
repeat=False)
720718
feed_dict_val = self._construct_feed_dict(data_iterator_val, split)
721719
cummulative_acc = 0.0
722720
num_samples = 0
723721
while feed_dict_val is not None:
724722
val_acc, batch_size_actual = session.run(
725-
(self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
723+
(self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
726724
cummulative_acc += val_acc * batch_size_actual
727725
num_samples += batch_size_actual
728726
feed_dict_val = self._construct_feed_dict(data_iterator_val, split)
@@ -732,8 +730,8 @@ def _evaluate(self, indices, split, session, summary_writer):
732730
if self.enable_summaries:
733731
summary = tf.Summary()
734732
summary.value.add(
735-
tag='ClassificationModel/' + split + '_acc',
736-
simple_value=cummulative_acc)
733+
tag='ClassificationModel/' + split + '_acc',
734+
simple_value=cummulative_acc)
737735
iter_cls_total = session.run(self.iter_cls_total)
738736
summary_writer.add_summary(summary, iter_cls_total)
739737
summary_writer.flush()

0 commit comments

Comments
 (0)