Skip to content

Commit

Permalink
Use tf.select instead of tf.where (compat TF 0.11)
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Dec 17, 2016
1 parent 914d976 commit 30fa61d
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,7 +1707,7 @@ def rnn(step_function, inputs, initial_states,
for input, mask_t in zip(input_list, mask_list):
output, new_states = step_function(input, states + constants)

# tf.where needs its condition tensor
# tf.select needs its condition tensor
# to be the same shape as its two
# result tensors, but in our case
# the condition (mask) tensor is
Expand All @@ -1725,16 +1725,16 @@ def rnn(step_function, inputs, initial_states,
else:
prev_output = successive_outputs[-1]

output = tf.where(tiled_mask_t, output, prev_output)
output = tf.select(tiled_mask_t, output, prev_output)

return_states = []
for state, new_state in zip(states, new_states):
# (see earlier comment for tile explanation)
tiled_mask_t = tf.tile(mask_t,
stack([1, tf.shape(new_state)[1]]))
return_states.append(tf.where(tiled_mask_t,
new_state,
state))
return_states.append(tf.select(tiled_mask_t,
new_state,
state))
states = return_states
successive_outputs.append(output)
successive_states.append(states)
Expand Down Expand Up @@ -1795,8 +1795,8 @@ def _step(time, output_ta_t, *states):
new_state.set_shape(state.get_shape())
tiled_mask_t = tf.tile(mask_t,
stack([1, tf.shape(output)[1]]))
output = tf.where(tiled_mask_t, output, states[0])
new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
output = tf.select(tiled_mask_t, output, states[0])
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
else:
Expand Down Expand Up @@ -1921,7 +1921,7 @@ def elu(x, alpha=1.):
if alpha == 1:
return res
else:
return tf.where(x > 0, res, alpha * res)
return tf.select(x > 0, res, alpha * res)


def softmax(x):
Expand Down Expand Up @@ -2384,9 +2384,9 @@ def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
if seed is None:
seed = np.random.randint(10e6)
return tf.where(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))


# CTC
Expand Down

0 comments on commit 30fa61d

Please sign in to comment.