Skip to content

Commit

Permalink
modify batch norm to make the model trainable, more efficient interle…
Browse files Browse the repository at this point in the history
…aving (jtatusko suggestion)
  • Loading branch information
he-dhamo committed Sep 22, 2017
1 parent e47e593 commit fbd4176
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 79 deletions.
110 changes: 46 additions & 64 deletions tensorflow/models/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@

DEFAULT_PADDING = 'SAME'


def get_incoming_shape(incoming):
""" Returns the incoming data shape """
if isinstance(incoming, tf.Tensor):
return incoming.get_shape().as_list()
elif type(incoming) in [np.array, list, tuple]:
return np.shape(incoming)
else:
raise Exception("Invalid incoming layer.")


def interleave(tensors, axis):
old_shape = get_incoming_shape(tensors[0])[1:]
new_shape = [-1] + old_shape
new_shape[axis] *= len(tensors)
return tf.reshape(tf.stack(tensors, axis + 1), new_shape)

def layer(op):
'''Decorator for composable network layers.'''

Expand Down Expand Up @@ -39,7 +56,7 @@ def layer_decorated(self, *args, **kwargs):

class Network(object):

def __init__(self, inputs, batch, trainable=False):
def __init__(self, inputs, batch, keep_prob, is_training, trainable = True):
# The input nodes for this network
self.inputs = inputs
# The current list of terminal nodes
Expand All @@ -48,9 +65,9 @@ def __init__(self, inputs, batch, trainable=False):
self.layers = dict(inputs)
# If true, the resulting variables are set as trainable
self.trainable = trainable
# Batch size needs to be set for the implementation of the interleaving
self.batch_size = batch

self.keep_prob = keep_prob
self.is_training = is_training
self.setup()


Expand All @@ -65,10 +82,9 @@ def load(self, data_path, session, ignore_missing=False):
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
data_dict = np.load(data_path, encoding='latin1').item()
for op_name in data_dict:

for op_name in data_dict:
with tf.variable_scope(op_name, reuse=True):
for param_name, data in iter(data_dict[op_name].items()):
for param_name, data in iter(data_dict[op_name].items()):
try:
var = tf.get_variable(param_name)
session.run(var.assign(data))
Expand Down Expand Up @@ -239,51 +255,42 @@ def softmax(self, input_data, name):

@layer
def batch_normalization(self, input_data, name, scale_offset=True, relu=False):
# NOTE: Currently, only inference is supported

with tf.variable_scope(name) as scope:
shape = [input_data.get_shape()[-1]]
pop_mean = tf.get_variable("mean", shape, initializer = tf.constant_initializer(0.0), trainable=False)
pop_var = tf.get_variable("variance", shape, initializer = tf.constant_initializer(1.0), trainable=False)
epsilon = 1e-4
decay = 0.999
if scale_offset:
scale = self.make_var('scale', shape=shape)
offset = self.make_var('offset', shape=shape)
scale = tf.get_variable("scale", shape, initializer = tf.constant_initializer(1.0))
offset = tf.get_variable("offset", shape, initializer = tf.constant_initializer(0.0))
else:
scale, offset = (None, None)
output = tf.nn.batch_normalization(
input_data,
mean=self.make_var('mean', shape=shape),
variance=self.make_var('variance', shape=shape),
offset=offset,
scale=scale,
variance_epsilon=1e-4,
name=name)
if self.is_training:
batch_mean, batch_var = tf.nn.moments(input_data, [0, 1, 2])

train_mean = tf.assign(pop_mean,
pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var,
pop_var * decay + batch_var * (1 - decay))
with tf.control_dependencies([train_mean, train_var]):
output = tf.nn.batch_normalization(input_data,
batch_mean, batch_var, offset, scale, epsilon, name = name)
else:
output = tf.nn.batch_normalization(input_data,
pop_mean, pop_var, offset, scale, epsilon, name = name)

if relu:
output = tf.nn.relu(output)


return output

@layer
def dropout(self, input_data, keep_prob, name):
return tf.nn.dropout(input_data, keep_prob, name=name)


# -------------------------------------------------------
# Additional operations, specific to FCRN
# -------------------------------------------------------

def prepare_indices(self, before, row, col, after, dims ):

x0, x1, x2, x3 = np.meshgrid(before, row, col, after)

x_0 = tf.Variable(x0.reshape([-1]), name = 'x_0')
x_1 = tf.Variable(x1.reshape([-1]), name = 'x_1')
x_2 = tf.Variable(x2.reshape([-1]), name = 'x_2')
x_3 = tf.Variable(x3.reshape([-1]), name = 'x_3')

linear_indices = x_3 + dims[3].value * x_2 + 2 * dims[2].value * dims[3].value * x_0 * 2 * dims[1].value + 2 * dims[2].value * dims[3].value * x_1
linear_indices_int = tf.to_int32(linear_indices)

return linear_indices_int

def unpool_as_conv(self, size, input_data, id, stride = 1, ReLU = False, BN = True):

# Model upconvolutions (unpooling + convolution) as interleaving feature
Expand Down Expand Up @@ -323,34 +330,9 @@ def unpool_as_conv(self, size, input_data, id, stride = 1, ReLU = False, BN = Tr

# Interleaving elements of the four feature maps
# --------------------------------------------------
dims = outputA.get_shape()
dim1 = dims[1] * 2
dim2 = dims[2] * 2

A_row_indices = range(0, dim1, 2)
A_col_indices = range(0, dim2, 2)
B_row_indices = range(1, dim1, 2)
B_col_indices = range(0, dim2, 2)
C_row_indices = range(0, dim1, 2)
C_col_indices = range(1, dim2, 2)
D_row_indices = range(1, dim1, 2)
D_col_indices = range(1, dim2, 2)

all_indices_before = range(int(self.batch_size))
all_indices_after = range(dims[3])

A_linear_indices = self.prepare_indices(all_indices_before, A_row_indices, A_col_indices, all_indices_after, dims)
B_linear_indices = self.prepare_indices(all_indices_before, B_row_indices, B_col_indices, all_indices_after, dims)
C_linear_indices = self.prepare_indices(all_indices_before, C_row_indices, C_col_indices, all_indices_after, dims)
D_linear_indices = self.prepare_indices(all_indices_before, D_row_indices, D_col_indices, all_indices_after, dims)

A_flat = tf.reshape(tf.transpose(outputA, [1, 0, 2, 3]), [-1])
B_flat = tf.reshape(tf.transpose(outputB, [1, 0, 2, 3]), [-1])
C_flat = tf.reshape(tf.transpose(outputC, [1, 0, 2, 3]), [-1])
D_flat = tf.reshape(tf.transpose(outputD, [1, 0, 2, 3]), [-1])

Y_flat = tf.dynamic_stitch([A_linear_indices, B_linear_indices, C_linear_indices, D_linear_indices], [A_flat, B_flat, C_flat, D_flat])
Y = tf.reshape(Y_flat, shape = tf.to_int32([-1, dim1.value, dim2.value, dims[3].value]))
left = interleave([outputA, outputB], axis=1) # columns
right = interleave([outputC, outputD], axis=1) # columns
Y = interleave([left, right], axis=2) # rows

if BN:
layerName = "layer%s_BN" % (id)
Expand Down
27 changes: 12 additions & 15 deletions tensorflow/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

def predict(model_data_path, image_path):


# Default input size
height = 228
width = 304
channels = 3
batch_size = 1

# Read image
img = Image.open(image_path)
img = img.resize([width,height], Image.ANTIALIAS)
Expand All @@ -23,26 +24,22 @@ def predict(model_data_path, image_path):

# Create a placeholder for the input image
input_node = tf.placeholder(tf.float32, shape=(None, height, width, channels))

# Construct the network
net = models.ResNet50UpProj({'data': input_node}, batch_size)
net = models.ResNet50UpProj({'data': input_node}, batch_size, 1, False)

with tf.Session() as sess:

# Load the converted parameters
print('Loading the model')
net.load(model_data_path, sess)

uninitialized_vars = []
for var in tf.global_variables():
try:
sess.run(var)
except tf.errors.FailedPreconditionError:
uninitialized_vars.append(var)

init_new_vars_op = tf.variables_initializer(uninitialized_vars)
sess.run(init_new_vars_op)

# Use to load from ckpt file
saver = tf.train.Saver()
saver.restore(sess, model_data_path)

# Use to load from npy file
#net.load(model_data_path, sess)

# Evalute the network for the given image
pred = sess.run(net.get_output(), feed_dict={input_node: img})

Expand All @@ -51,7 +48,7 @@ def predict(model_data_path, image_path):
ii = plt.imshow(pred[0,:,:,0], interpolation='nearest')
fig.colorbar(ii)
plt.show()

return pred


Expand Down

0 comments on commit fbd4176

Please sign in to comment.