|
| 1 | +"""Implements a voxel flow model.""" |
| 2 | +from __future__ import absolute_import |
| 3 | +from __future__ import division |
| 4 | +from __future__ import print_function |
| 5 | + |
| 6 | +import tensorflow as tf |
| 7 | +import tensorflow.contrib.slim as slim |
| 8 | +from utils.loss_utils import l1_loss, l2_loss, vae_loss |
| 9 | +from utils.geo_layer_utils import vae_gaussian_layer |
| 10 | +from utils.geo_layer_utils import bilinear_interp |
| 11 | +from utils.geo_layer_utils import meshgrid |
| 12 | + |
# Process-wide command-line flags; this module reads FLAGS.batch_size when
# tiling the sampling grid in Voxel_flow_model._build_model.
FLAGS = tf.app.flags.FLAGS
# Smoothing constant: its square is added under the square root of the
# sqrt(x**2 + epsilon**2) penalties so they stay differentiable at x == 0.
epsilon = 0.001
| 15 | + |
| 16 | + |
class Voxel_flow_model(object):
    """Encoder/decoder network that predicts a 2-channel flow and a blend mask.

    The network output is used to resample channels 0:3 and 3:6 of the input
    in opposite directions and to blend the two resampled images.

    Attributes set by ``inference`` (all ``None`` until it has been called):
        flow: raw 2-channel tanh output (before any scaling).
        mask: 1-channel blend weight, ``0.5 * (1 + tanh output)``.
        warped_img1, warped_img2: the two resampled input images.
        warped_flow1, warped_flow2: the scaled flow resampled at the same
            coordinates as the images.
    Attributes set by ``loss`` / ``l1loss``:
        reproduction_loss, motion_loss, mask_loss.
    """

    # (output depth, square kernel size) of the conv pair in each encoder
    # stage; the decoder mirrors the depths in reverse order with 3x3 kernels.
    _ENCODER_STAGES = ((32, 7), (64, 5), (128, 3), (256, 3), (512, 3))
    _BOTTLENECK_DEPTH = 512

    def __init__(self, is_train=True):
        """Store the training flag; no graph ops are created here.

        Args:
            is_train: forwarded to batch normalization as ``is_training``.
        """
        self.is_train = is_train

        # Declared up front so that using the model before `inference` fails
        # with a clear message instead of a missing-attribute error.
        self.flow = None
        self.mask = None
        self.warped_img1 = None
        self.warped_img2 = None
        self.warped_flow1 = None
        self.warped_flow2 = None
        self.reproduction_loss = None
        self.motion_loss = None
        self.mask_loss = None

    def inference(self, input_images):
        """Build the network on `input_images`.

        Args:
            input_images: 4-D tensor; channels 0:3 and 3:6 are resampled
                separately (assumes NHWC layout -- TODO confirm with caller).

        Returns:
            ``[blended, raw]`` where ``blended`` is the mask-weighted mix of
            the two resampled images and ``raw`` is the 3-channel tanh output
            of the last convolution.
        """
        return self._build_model(input_images)

    def total_var(self, images):
        """Return a smoothed total-variation penalty of a 4-D tensor.

        Takes forward differences along axis 1 and along axis 2, passes each
        through ``sqrt(d**2 + epsilon**2)`` and adds the two means.
        """
        diff_axis1 = images[:, 1:, :, :] - images[:, :-1, :, :]
        diff_axis2 = images[:, :, 1:, :] - images[:, :, :-1, :]
        penalty_axis1 = tf.reduce_mean(
            tf.sqrt(tf.square(diff_axis1) + epsilon**2))
        penalty_axis2 = tf.reduce_mean(
            tf.sqrt(tf.square(diff_axis2) + epsilon**2))
        return penalty_axis1 + penalty_axis2

    def loss(self, predictions, targets):
        """Compute the training loss.

        The result is ``reproduction + 0.01 * motion + 0.005 * mask`` where
        reproduction is the mean of ``sqrt((predictions - targets)**2 +
        epsilon**2)`` and the other two terms are `total_var` of ``self.flow``
        and ``self.mask``. The three terms are also stored on the instance.

        Args:
            predictions: tensor produced by the model.
            targets: tensor compatible with `predictions` for subtraction.

        Returns:
            Scalar loss tensor.

        Raises:
            AttributeError: if `inference` has not been called on this
                instance, so ``self.flow`` / ``self.mask`` are not available.
        """
        if self.flow is None or self.mask is None:
            raise AttributeError(
                'Voxel_flow_model.loss() needs self.flow and self.mask; '
                'call inference() on this instance first.')

        self.reproduction_loss = tf.reduce_mean(
            tf.sqrt(tf.square(predictions - targets) + epsilon**2))
        self.motion_loss = self.total_var(self.flow)
        self.mask_loss = self.total_var(self.mask)

        return (self.reproduction_loss
                + 0.01 * self.motion_loss
                + 0.005 * self.mask_loss)

    def l1loss(self, predictions, targets):
        """Store and return ``l1_loss(predictions, targets)``."""
        self.reproduction_loss = l1_loss(predictions, targets)
        return self.reproduction_loss

    @staticmethod
    def _conv_pair(net, depth, kernel, scope):
        """Apply two stride-1 convs scoped `scope` and `scope + '_1'`."""
        net = slim.conv2d(net, depth, [kernel, kernel], stride=1, scope=scope)
        return slim.conv2d(net, depth, [kernel, kernel], stride=1,
                           scope=scope + '_1')

    def _encode_decode(self, input_images):
        """Run the conv1..conv11_1 encoder/decoder and return its features.

        Must be called inside the conv2d/batch_norm arg_scopes set up by
        `_build_model`; layer scope names are part of the variable names and
        are kept as conv1 ... conv11_1 / pool1 ... pool5.
        """
        skips = []
        net = input_images
        for stage, (depth, kernel) in enumerate(self._ENCODER_STAGES, 1):
            net = self._conv_pair(net, depth, kernel, 'conv%d' % stage)
            skips.append(net)
            net = slim.avg_pool2d(net, [2, 2], scope='pool%d' % stage)

        net = self._conv_pair(net, self._BOTTLENECK_DEPTH, 3, 'conv6')

        decoder_depths = [depth for depth, _ in reversed(self._ENCODER_STAGES)]
        for offset, (skip, depth) in enumerate(
                zip(reversed(skips), decoder_depths)):
            # Upsample to the static spatial size of the matching encoder
            # stage, then fuse with that stage's output along the last axis.
            height, width = skip.get_shape().as_list()[1:3]
            net = tf.image.resize_bilinear(net, [height, width])
            net = self._conv_pair(tf.concat([net, skip], -1), depth, 3,
                                  'conv%d' % (7 + offset))
        return net

    def _build_model(self, input_images):
        """Create all graph ops and return ``[blended, raw_output]``.

        Side effects: sets ``self.flow``, ``self.mask``, ``self.warped_img1``,
        ``self.warped_img2``, ``self.warped_flow1`` and ``self.warped_flow2``.
        """
        batch_norm_params = {
            'decay': 0.9997,
            'epsilon': 0.001,
            'is_training': self.is_train,
        }
        with slim.arg_scope(
                [slim.conv2d],
                activation_fn=tf.nn.leaky_relu,
                weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                weights_regularizer=slim.l2_regularizer(0.0001)):
            with slim.arg_scope([slim.batch_norm], is_training=self.is_train,
                                updates_collections=None):
                with slim.arg_scope([slim.conv2d],
                                    normalizer_fn=slim.batch_norm,
                                    normalizer_params=batch_norm_params):
                    features = self._encode_decode(input_images)

        # NOTE(review): this layer is created outside the arg_scopes above, so
        # it uses slim's default initializer and no weight regularizer.
        # Confirm this nesting matches the trained checkpoints' graph.
        net = slim.conv2d(features, 3, [5, 5], stride=1, activation_fn=tf.tanh,
                          normalizer_fn=None, scope='conv12')
        net_copy = net

        flow = net[:, :, :, 0:2]
        mask = tf.expand_dims(net[:, :, :, 2], 3)
        self.flow = flow

        # `features` keeps the static spatial size of the first conv layer.
        height, width = features.get_shape().as_list()[1:3]
        grid_x, grid_y = meshgrid(height, width)
        # The grids are repeated FLAGS.batch_size times along the first axis;
        # presumably the input batch must have that size -- TODO confirm.
        grid_x = tf.tile(grid_x, [FLAGS.batch_size, 1, 1])
        grid_y = tf.tile(grid_y, [FLAGS.batch_size, 1, 1])

        # Per-channel scaling of the halved flow by 255 / (size - 1).
        # NOTE(review): 255.0 presumably reflects a 256-pixel reference
        # resolution -- confirm against meshgrid / bilinear_interp.
        flow_ratio = tf.constant([255.0 / (width - 1), 255.0 / (height - 1)])
        flow = 0.5 * flow
        flow = flow * tf.reshape(flow_ratio, [1, 1, 1, 2])

        flow_x = flow[:, :, :, 0]
        flow_y = flow[:, :, :, 1]
        coor_x_1 = grid_x + flow_x
        coor_y_1 = grid_y + flow_y
        coor_x_2 = grid_x - flow_x
        coor_y_2 = grid_y - flow_y

        output_1 = bilinear_interp(input_images[:, :, :, 0:3],
                                   coor_x_1, coor_y_1, 'interpolate')
        output_2 = bilinear_interp(input_images[:, :, :, 3:6],
                                   coor_x_2, coor_y_2, 'interpolate')
        self.warped_img1 = output_1
        self.warped_img2 = output_2

        # `flow` has exactly two channels, so it is passed whole; the sign is
        # opposite for the two sampling directions.
        self.warped_flow1 = bilinear_interp(-0.5 * flow,
                                            coor_x_1, coor_y_1, 'interpolate')
        self.warped_flow2 = bilinear_interp(0.5 * flow,
                                            coor_x_2, coor_y_2, 'interpolate')

        # Map the tanh output from [-1, 1] to a [0, 1] blend weight.
        mask = 0.5 * (1.0 + mask)
        self.mask = mask
        mask = tf.tile(mask, [1, 1, 1, 3])
        net = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)

        return [net, net_copy]
0 commit comments