Merge pull request #15 from rhee-airilab/master
Fixing pre-1.0 API calls
carpedm20 authored Oct 25, 2017
2 parents ec0a59b + db847a1 commit 1d98c28
Showing 5 changed files with 82 additions and 47 deletions.
9 changes: 6 additions & 3 deletions cifar10.py
@@ -122,7 +122,8 @@ def _generate_image_and_label_batch(image, label, min_queue_examples,
min_after_dequeue=min_queue_examples)

# Display the training images in the visualizer.
- tf.image_summary('images', images)
+ # FIXED pre-1.0 # tf.image_summary('images', images)
+ tf.summary.image('images', images)

return images, tf.reshape(label_batch, [batch_size])

@@ -171,7 +172,8 @@ def distorted_inputs(data_dir, batch_size):
lower=0.2, upper=1.8)

# Subtract off the mean and divide by the variance of the pixels.
- float_image = tf.image.per_image_whitening(distorted_image)
+ # FIXED pre-1.0 # float_image = tf.image.per_image_whitening(distorted_image)
+ float_image = tf.image.per_image_standardization(distorted_image)

# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4
@@ -225,7 +227,8 @@ def inputs(eval_data, data_dir, batch_size):
width, height)

# Subtract off the mean and divide by the variance of the pixels.
- float_image = tf.image.per_image_whitening(resized_image)
+ # FIXED pre-1.0 # float_image = tf.image.per_image_whitening(resized_image)
+ float_image = tf.image.per_image_standardization(resized_image)

# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4
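Note: both cifar10.py fixes are straight renames in TensorFlow 1.0. A minimal sketch of the new calls, assuming TF 1.x; the dummy image batch below is illustrative, not from the repository:

import tensorflow as tf

# tf.image_summary(...)             -> tf.summary.image(...)
# tf.image.per_image_whitening(...) -> tf.image.per_image_standardization(...)
images = tf.random_uniform([8, 24, 24, 3])                    # illustrative dummy batch
image_summary_op = tf.summary.image('images', images)
float_image = tf.image.per_image_standardization(images[0])   # expects a single 3-D image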
16 changes: 8 additions & 8 deletions main.py
@@ -20,12 +20,12 @@
flags.DEFINE_integer("out_hidden_dims", 32, "dimension of hidden states of output Conv layers")
flags.DEFINE_integer("out_recurrent_length", 2, "the length of output Conv layers")
flags.DEFINE_boolean("use_residual", False, "whether to use residual connections or not")
- flags.DEFINE_boolean("use_dynamic_rnn", False, "whether to use dynamic_rnn or not")
+ # flags.DEFINE_boolean("use_dynamic_rnn", False, "whether to use dynamic_rnn or not")

# training
- flags.DEFINE_float("max_epoch", 100000, "# of step in an epoch")
- flags.DEFINE_float("test_step", 100, "# of step to test a model")
- flags.DEFINE_float("save_step", 1000, "# of step to save a model")
+ flags.DEFINE_integer("max_epoch", 100000, "# of step in an epoch")
+ flags.DEFINE_integer("test_step", 100, "# of step to test a model")
+ flags.DEFINE_integer("save_step", 1000, "# of step to save a model")
flags.DEFINE_float("learning_rate", 1e-3, "learning rate")
flags.DEFINE_float("grad_clip", 1, "value of gradient to be used for clipping")
flags.DEFINE_boolean("use_gpu", True, "whether to use gpu for training")
@@ -52,7 +52,7 @@
np.random.seed(conf.random_seed)

def main(_):
- model_dir = get_model_dir(conf,
+ model_dir = get_model_dir(conf,
['data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step',
'is_train', 'random_seed', 'log_level', 'display'])
preprocess_conf(conf)
@@ -79,7 +79,7 @@ def main(_):
from cifar10 import IMAGE_SIZE, inputs

maybe_download_and_extract(DATA_DIR)
- images, labels = inputs(eval_data=False,
+ images, labels = inputs(eval_data=False,
data_dir=os.path.join(DATA_DIR, 'cifar-10-batches-bin'), batch_size=conf.batch_size)

height, width, channel = IMAGE_SIZE, IMAGE_SIZE, 3
@@ -105,7 +105,7 @@ def main(_):

cost = network.test(images, with_update=True)
total_train_costs.append(cost)

# 2. test
total_test_costs = []
for idx in xrange(test_step_per_epoch):
@@ -121,7 +121,7 @@

# 3. generate samples
samples = network.generate()
- save_images(samples, height, width, 10, 10,
+ save_images(samples, height, width, 10, 10,
directory=SAMPLE_DIR, prefix="epoch_%s" % epoch)

iterator.set_description("train l: %.3f, test l: %.3f" % (avg_train_cost, avg_test_cost))
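Note: the main.py changes are type corrections rather than renames: max_epoch, test_step, and save_step drive xrange() loops and step-modulo checks, so they should be integer flags (the remaining diffs in this file only strip trailing whitespace). A small sketch of the distinction, assuming TF 1.x tf.app.flags; the flag name mirrors the one above:

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("test_step", 100, "# of step to test a model")
conf = flags.FLAGS

def main(_):
    # An integer flag can drive xrange() directly. With DEFINE_float the
    # value parses as 100.0, which Python 2 only tolerates with a
    # DeprecationWarning when passed to xrange().
    for idx in xrange(conf.test_step):
        pass

if __name__ == '__main__':
    tf.app.run()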
26 changes: 18 additions & 8 deletions network.py
@@ -43,7 +43,7 @@ def __init__(self, sess, conf, height, width, channel):
self.l[scope] = conv2d(self.l['normalized_inputs'], conf.hidden_dims * 2, [7, 7], "A", scope=scope)
else:
self.l[scope] = conv2d(self.l['normalized_inputs'], conf.hidden_dims, [7, 7], "A", scope=scope)

# main recurrent layers
l_hid = self.l[scope]
for idx in xrange(conf.recurrent_length):
@@ -68,8 +68,11 @@ def __init__(self, sess, conf, height, width, channel):
self.l['output'] = tf.nn.sigmoid(self.l['conv2d_out_logits'])

logger.info("Building loss and optims")
+ # FIXED pre-1.0
+ # self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
+ #     self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss'))
self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
-     self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss'))
+     logits=self.l['conv2d_out_logits'], labels=self.l['normalized_inputs'], name='loss'))
else:
raise ValueError("Implementation in progress for RGB colors")

@@ -82,10 +85,14 @@ def __init__(self, sess, conf, height, width, channel):
self.l['normalized_inputs_flat'] = tf.reshape(
self.l['normalized_inputs'], [-1, self.height * self.width, COLOR_DIM])

- pred_pixels = [tf.squeeze(pixel, squeeze_dims=[1])
-     for pixel in tf.split(1, self.height * self.width, self.l['conv2d_out_logits_flat'])]
- target_pixels = [tf.squeeze(pixel, squeeze_dims=[1])
-     for pixel in tf.split(1, self.height * self.width, self.l['normalized_inputs_flat'])]
+ # FIXED pre-1.0 # pred_pixels = [tf.squeeze(pixel, squeeze_dims=[1])
+ pred_pixels = [tf.squeeze(pixel, axis=[1])
+ # FIXED pre-1.0 # for pixel in tf.split(1, self.height * self.width, self.l['conv2d_out_logits_flat'])]
+     for pixel in tf.split(self.l['conv2d_out_logits_flat'], self.height * self.width, 1)]
+ # FIXED pre-1.0 # target_pixels = [tf.squeeze(pixel, squeeze_dims=[1])
+ target_pixels = [tf.squeeze(pixel, axis=[1])
+ # FIXED pre-1.0 # for pixel in tf.split(1, self.height * self.width, self.l['normalized_inputs_flat'])]
+     for pixel in tf.split(self.l['normalized_inputs_flat'], self.height * self.width, 1)]

softmaxed_pixels = [tf.nn.softmax(pixel) for pixel in pred_pixels]

@@ -96,16 +103,19 @@ def __init__(self, sess, conf, height, width, channel):
self.l['output'] = tf.nn.softmax(self.l['conv2d_out_logits'])

logger.info("Building loss and optims")
+ # FIXED pre-1.0
+ # self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
+ #     self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss'))
self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
-     self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss'))
+     logits=self.l['conv2d_out_logits'], labels=self.l['normalized_inputs'], name='loss'))

optimizer = tf.train.RMSPropOptimizer(conf.learning_rate)
grads_and_vars = optimizer.compute_gradients(self.loss)

new_grads_and_vars = \
[(tf.clip_by_value(gv[0], -conf.grad_clip, conf.grad_clip), gv[1]) for gv in grads_and_vars]
self.optim = optimizer.apply_gradients(new_grads_and_vars)

show_all_variables()

logger.info("Building %s finished!" % conf.model)
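Note: the loss fixes reflect the TensorFlow 1.0 signature change for the cross-entropy helpers: logits and labels must now be passed as keyword arguments. A minimal sketch, assuming TF 1.x, with placeholder shapes of my own:

import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 10])
labels = tf.placeholder(tf.float32, [None, 10])

# pre-1.0 (positional, no longer accepted):
#   tf.nn.sigmoid_cross_entropy_with_logits(logits, labels, name='loss')
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=logits, labels=labels, name='loss'))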
69 changes: 44 additions & 25 deletions ops.py
@@ -30,7 +30,8 @@ def get_shape(layer):
def skew(inputs, scope="skew"):
with tf.name_scope(scope):
batch, height, width, channel = get_shape(inputs) # [batch, height, width, channel]
- rows = tf.split(1, height, inputs) # [batch, 1, width, channel]
+ # FIXED pre-1.0 # rows = tf.split(1, height, inputs) # [batch, 1, width, channel]
+ rows = tf.split(inputs, height, 1) # [batch, 1, width, channel]

new_width = width + height - 1
new_rows = []
@@ -46,7 +47,8 @@ def skew(inputs, scope="skew"):
assert get_shape(untransposed_row) == [batch, new_width, channel], "wrong shape of skewed row"
new_rows.append(untransposed_row)

outputs = tf.pack(new_rows, axis=1, name="output")
# FIXED pre-1.0 # outputs = tf.pack(new_rows, axis=1, name="output")
outputs = tf.stack(new_rows, axis=1, name="output")
assert get_shape(outputs) == [None, height, new_width, channel], "wrong shape of skewed output"

logger.debug('[skew] %s : %s %s -> %s %s' \
@@ -59,11 +61,13 @@ def unskew(inputs, width=None, scope="unskew"):
width = width if width else height

new_rows = []
- rows = tf.split(1, height, inputs)
+ # FIXED pre-1.0 # rows = tf.split(1, height, inputs)
+ rows = tf.split(inputs, height, 1)

for idx, row in enumerate(rows):
new_rows.append(tf.slice(row, [0, 0, idx, 0], [-1, -1, width, -1]))
- outputs = tf.concat(1, new_rows, name="output")
+ # FIXED pre-1.0 # outputs = tf.concat(1, new_rows, name="output")
+ outputs = tf.concat(new_rows, 1, name="output")

logger.debug('[unskew] %s : %s %s -> %s %s' \
% (scope, inputs.name, inputs.get_shape(), outputs.name, outputs.get_shape()))
Expand All @@ -79,7 +83,8 @@ def conv2d(
activation_fn=None,
weights_initializer=WEIGHT_INITIALIZER,
weights_regularizer=None,
- biases_initializer=tf.zeros_initializer,
+ # FIXED pre-1.0 # biases_initializer=tf.zeros_initializer,
+ biases_initializer=tf.zeros_initializer(),
biases_regularizer=None,
scope="conv2d"):
with tf.variable_scope(scope):
@@ -138,7 +143,8 @@ def conv1d(
activation_fn=None,
weights_initializer=WEIGHT_INITIALIZER,
weights_regularizer=None,
- biases_initializer=tf.zeros_initializer,
+ # FIXED pre-1.0 # biases_initializer=tf.zeros_initializer,
+ biases_initializer=tf.zeros_initializer(),
biases_regularizer=None,
scope="conv1d"):
with tf.variable_scope(scope):
@@ -152,7 +158,7 @@ def conv1d(
tf.float32, weights_initializer, weights_regularizer)
tf.add_to_collection('conv1d_weights', weights)

- outputs = tf.nn.conv2d(inputs,
+ outputs = tf.nn.conv2d(inputs,
weights, [1, stride_h, stride_w, 1], padding=padding, name='outputs')
tf.add_to_collection('conv1d_outputs', weights)

@@ -172,7 +178,8 @@ def conv1d(
def diagonal_bilstm(inputs, conf, scope='diagonal_bilstm'):
with tf.variable_scope(scope):
def reverse(inputs):
- return tf.reverse(inputs, [False, False, True, False])
+ # FIXED pre-1.0 # return tf.reverse(inputs, [False, False, True, False])
+ return tf.reverse(inputs, [2]) # [False, False, True, False])

output_state_fw = diagonal_lstm(inputs, conf, scope='output_state_fw')
output_state_bw = reverse(diagonal_lstm(reverse(inputs), conf, scope='output_state_bw'))
@@ -198,7 +205,8 @@ def reverse(inputs):
output_state_bw_only_last = tf.slice(output_state_bw, [0, height-1, 0, 0], [-1, 1, -1, -1])
dummy_zeros = tf.zeros_like(output_state_bw_only_last)

- output_state_bw_with_last_zeros = tf.concat(1, [output_state_bw_except_last, dummy_zeros])
+ # FIXED pre-1.0 # output_state_bw_with_last_zeros = tf.concat(1, [output_state_bw_except_last, dummy_zeros])
+ output_state_bw_with_last_zeros = tf.concat([output_state_bw_except_last, dummy_zeros], 1)

tf.add_to_collection('output_state_bw_with_last_zeros', output_state_bw_with_last_zeros)

@@ -225,27 +233,35 @@ def diagonal_lstm(inputs, conf, scope='diagonal_lstm'):

tf.add_to_collection('rnn_inputs', rnn_inputs)

- rnn_input_list = [tf.squeeze(rnn_input, squeeze_dims=[1])
-     for rnn_input in tf.split(split_dim=1, num_split=width, value=rnn_inputs)]
+ # FIXED pre-1.0 # rnn_input_list = [tf.squeeze(rnn_input, squeeze_dims=[1])
+ rnn_input_list = [tf.squeeze(rnn_input, axis=[1])
+ # FIXED pre-1.0 # for rnn_input in tf.split(split_dim=1, num_split=width, value=rnn_inputs)]
+     for rnn_input in tf.split(rnn_inputs, width, 1)]

cell = DiagonalLSTMCell(conf.hidden_dims, height, channel)

- if conf.use_dynamic_rnn:
+ # if conf.use_dynamic_rnn:
+ if True:
+     # XXX FIXME: sequence_length ?
outputs, states = tf.nn.dynamic_rnn(cell,
    inputs=rnn_inputs, dtype=tf.float32) # [batch, width, height * hidden_dims]
- else:
-     output_list, state_list = tf.nn.rnn(cell,
-         inputs=rnn_input_list, dtype=tf.float32) # width * [batch, height * hidden_dims]
+ packed_outputs = outputs # dynamic_rnn(), [batch, width, height * hidden_dims]

+ # else:
+ #     output_list, state_list = tf.nn.rnn(cell,
+ #         inputs=rnn_input_list, dtype=tf.float32) # width * [batch, height * hidden_dims]

+ # # FIXED pre-1.0 # packed_outputs = tf.pack(output_list, 1) # [batch, width, height * hidden_dims]
+ # packed_outputs = tf.stack(output_list, 1) # [batch, width, height * hidden_dims]

- packed_outputs = tf.pack(output_list, 1) # [batch, width, height * hidden_dims]
- width_first_outputs = tf.reshape(packed_outputs,
-     [-1, width, height, conf.hidden_dims]) # [batch, width, height, hidden_dims]
+ width_first_outputs = tf.reshape(packed_outputs,
+     [-1, width, height, conf.hidden_dims]) # [batch, width, height, hidden_dims]

- skewed_outputs = tf.transpose(width_first_outputs, [0, 2, 1, 3])
- tf.add_to_collection('skewed_outputs', skewed_outputs)
+ skewed_outputs = tf.transpose(width_first_outputs, [0, 2, 1, 3])
+ tf.add_to_collection('skewed_outputs', skewed_outputs)

- outputs = unskew(skewed_outputs)
- tf.add_to_collection('unskewed_outputs', outputs)
+ outputs = unskew(skewed_outputs)
+ tf.add_to_collection('unskewed_outputs', outputs)

return outputs

@@ -299,15 +315,18 @@ def __call__(self, i_to_s, state, scope="DiagonalBiLSTMCell"):
lstm_matrix = tf.sigmoid(s_to_s + i_to_s)

# i = input_gate, g = new_input, f = forget_gate, o = output_gate
- i, g, f, o = tf.split(1, 4, lstm_matrix)
+ # FIXED pre-1.0 # i, g, f, o = tf.split(1, 4, lstm_matrix)
+ i, g, f, o = tf.split(lstm_matrix, 4, 1)

c = f * c_prev + i * g
- h = tf.mul(o, tf.tanh(c), name='hid')
+ # FIXED pre-1.0 # h = tf.mul(o, tf.tanh(c), name='hid')
+ h = tf.multiply(o, tf.tanh(c), name='hid')

logger.debug('[DiagonalLSTMCell] %s : %s %s -> %s %s' \
% (scope, i_to_s.name, i_to_s.get_shape(), h.name, h.get_shape()))

- new_state = tf.concat(1, [c, h])
+ # FIXED pre-1.0 # new_state = tf.concat(1, [c, h])
+ new_state = tf.concat([c, h], 1)
return h, new_state

class RowLSTMCell(rnn_cell.RNNCell):
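Note: nearly every ops.py edit follows one TensorFlow 1.0 rule: the tensor argument moved first and the axis argument last in tf.split and tf.concat, tf.pack became tf.stack, tf.mul became tf.multiply, squeeze_dims became axis, tf.reverse takes a list of axis indices instead of a per-dimension boolean mask, and tf.zeros_initializer became a class that must be instantiated. A minimal sketch of each, assuming TF 1.x; the tensor shape is illustrative:

import tensorflow as tf

x = tf.zeros([2, 4, 6])                              # illustrative 3-D tensor

parts = tf.split(x, 4, 1)                            # was: tf.split(1, 4, x)
joined = tf.concat(parts, 1)                         # was: tf.concat(1, parts)
stacked = tf.stack(parts, axis=1)                    # was: tf.pack(parts, axis=1)
squeezed = [tf.squeeze(p, axis=[1]) for p in parts]  # was: squeeze_dims=[1]
reversed_x = tf.reverse(x, [2])                      # was: mask like [False, False, True]
doubled = tf.multiply(x, 2.0)                        # was: tf.mul(x, 2.0)
bias_init = tf.zeros_initializer()                   # was: the bare tf.zeros_initializer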
9 changes: 6 additions & 3 deletions statistic.py
@@ -17,7 +17,8 @@ def __init__(self, sess, data, model_dir, variables, test_step, max_to_keep=20):

self.model_dir = model_dir
self.saver = tf.train.Saver(variables + [self.t_op], max_to_keep=max_to_keep)
- self.writer = tf.train.SummaryWriter('./logs/%s' % self.model_dir, self.sess.graph)
+ # FIXED pre-1.0 # self.writer = tf.train.SummaryWriter('./logs/%s' % self.model_dir, self.sess.graph)
+ self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph)

with tf.variable_scope('summary'):
scalar_summary_tags = ['train_l', 'test_l']
@@ -27,7 +28,8 @@ def __init__(self, sess, data, model_dir, variables, test_step, max_to_keep=20):

for tag in scalar_summary_tags:
self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_'))
- self.summary_ops[tag] = tf.scalar_summary('%s/%s' % (data, tag), self.summary_placeholders[tag])
+ # FIXED pre-1.0 # self.summary_ops[tag] = tf.scalar_summary('%s/%s' % (data, tag), self.summary_placeholders[tag])
+ self.summary_ops[tag] = tf.summary.scalar('%s/%s' % (data, tag), self.summary_placeholders[tag])

def reset(self):
pass
@@ -60,7 +62,8 @@ def save_model(self, t):

def load_model(self):
logger.info("Initializing all variables")
- tf.initialize_all_variables().run()
+ # FIXED pre-1.0 # tf.initialize_all_variables().run()
+ tf.global_variables_initializer().run()

logger.info("Loading checkpoints...")
ckpt = tf.train.get_checkpoint_state(self.model_dir)
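Note: the statistic.py edits are the summary and initialization renames from TensorFlow 1.0. A minimal end-to-end sketch, assuming TF 1.x; the log directory and tag are placeholders of my own:

import tensorflow as tf

loss_ph = tf.placeholder(tf.float32, None, name='train_l')
summary_op = tf.summary.scalar('train/train_l', loss_ph)      # was: tf.scalar_summary

with tf.Session() as sess:
    writer = tf.summary.FileWriter('/tmp/logs', sess.graph)   # was: tf.train.SummaryWriter
    sess.run(tf.global_variables_initializer())               # was: tf.initialize_all_variables()
    summary_str = sess.run(summary_op, {loss_ph: 0.5})
    writer.add_summary(summary_str, global_step=0)
    writer.flush()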
