Make sure that changes to the global floatx are effectively taken into account by the backend. (keras-team#4739)
jphalip authored and fchollet committed Dec 17, 2016
1 parent 30fa61d commit 79406f1
Showing 3 changed files with 117 additions and 34 deletions.
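The root cause is a standard Python pitfall: default argument values are evaluated once, when `def` executes at import time, so `dtype=_FLOATX` froze whatever floatx was configured when the backend module was first imported, and a later `set_floatx()` call was silently ignored. The fix below replaces every such default with a `None` sentinel that is resolved via `floatx()` at call time. A minimal self-contained sketch of the difference (the `floatx`/`set_floatx` helpers mirror the real ones in `keras.backend.common`; `variable_old`/`variable_new` are illustrative stand-ins):

```python
_FLOATX = 'float32'


def floatx():
    return _FLOATX


def set_floatx(value):
    global _FLOATX
    _FLOATX = value


def variable_old(value, dtype=_FLOATX):
    # _FLOATX was read once, when this `def` executed.
    return value, dtype


def variable_new(value, dtype=None):
    # The None sentinel defers the lookup to call time.
    if dtype is None:
        dtype = floatx()
    return value, dtype


set_floatx('float64')
print(variable_old(1.0))  # (1.0, 'float32') -- stale import-time default
print(variable_new(1.0))  # (1.0, 'float64') -- reflects the current global
```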
keras/backend/tensorflow_backend.py (61 changes: 40 additions & 21 deletions)
@@ -12,7 +12,7 @@
 import os
 import copy
 import warnings
-from .common import _FLOATX, _EPSILON, image_dim_ordering, reset_uids
+from .common import floatx, _EPSILON, image_dim_ordering, reset_uids
 py_all = all
 
 # INTERNAL UTILS
@@ -207,7 +207,7 @@ def to_dense(tensor):
     return tensor
 
 
-def variable(value, dtype=_FLOATX, name=None):
+def variable(value, dtype=None, name=None):
     '''Instantiates a variable and returns it.
 
     # Arguments
@@ -232,6 +232,8 @@ def variable(value, dtype=_FLOATX, name=None):
                [ 3.,  4.]])
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     if hasattr(value, 'tocoo'):
         sparse_coo = value.tocoo()
         indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
@@ -271,7 +273,7 @@ def _initialize_variables():
     sess.run(tf.initialize_variables(uninitialized_variables))
 
 
-def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
+def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
     '''Instantiates a placeholder tensor and returns it.
 
     # Arguments
@@ -296,6 +298,8 @@ def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
         <tf.Tensor 'Placeholder_4:0' shape=(2, 4, 5) dtype=float32>
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     if not shape:
         if ndim:
             shape = tuple([None for _ in range(ndim)])
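The observable effect of the sentinel: `set_floatx()` now propagates to variables and placeholders created after the call. A usage sketch against the public backend API (`set_floatx`, `variable`, `placeholder`, and `dtype` are existing `keras.backend` helpers):

```python
import numpy as np
from keras import backend as K

K.set_floatx('float64')
v = K.variable(np.array([[1., 2.], [3., 4.]]))
p = K.placeholder(shape=(2, 4, 5))

# Before this commit both would be created as float32, the value the
# global had at import time; after it they follow the updated setting.
print(K.dtype(p))  # 'float64'
print(K.dtype(v))  # 'float64' (TF may report a variable as 'float64_ref')
```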
@@ -448,7 +452,7 @@ def eval(x):
     return to_dense(x).eval(session=get_session())
 
 
-def zeros(shape, dtype=_FLOATX, name=None):
+def zeros(shape, dtype=None, name=None):
     '''Instantiates an all-zeros variable and returns it.
 
     # Arguments
@@ -469,13 +473,15 @@ def zeros(shape, dtype=_FLOATX, name=None):
                [ 0.,  0.,  0.,  0.]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     return variable(tf.constant_initializer(0., dtype=tf_dtype)(shape),
                     dtype, name)
 
 
-def ones(shape, dtype=_FLOATX, name=None):
+def ones(shape, dtype=None, name=None):
     '''Instantiates an all-ones tensor variable and returns it.
 
     # Arguments
@@ -498,13 +504,15 @@ def ones(shape, dtype=_FLOATX, name=None):
                [ 1.,  1.,  1.,  1.]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     return variable(tf.constant_initializer(1., dtype=tf_dtype)(shape),
                     dtype, name)
 
 
-def eye(size, dtype=_FLOATX, name=None):
+def eye(size, dtype=None, name=None):
     '''Instantiate an identity matrix and returns it.
 
     # Arguments
@@ -577,7 +585,7 @@ def ones_like(x, name=None):
     return tf.ones_like(x, name=name)
 
 
-def random_uniform_variable(shape, low, high, dtype=_FLOATX,
+def random_uniform_variable(shape, low, high, dtype=None,
                             name=None, seed=None):
     '''Instantiates an Keras variable filled with
     samples drawn from a uniform distribution and returns it.
@@ -609,6 +617,8 @@ def random_uniform_variable(shape, low, high, dtype=_FLOATX,
                [ 0.66137183,  0.00869417,  0.89220798]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     if seed is None:
@@ -619,7 +629,7 @@ def random_uniform_variable(shape, low, high, dtype=_FLOATX,
     return variable(value, dtype=dtype, name=name)
 
 
-def random_normal_variable(shape, mean, scale, dtype=_FLOATX,
+def random_normal_variable(shape, mean, scale, dtype=None,
                            name=None, seed=None):
     '''Instantiates an Keras variable filled with
     samples drawn from a normal distribution and returns it.
@@ -651,6 +661,8 @@ def random_normal_variable(shape, mean, scale, dtype=_FLOATX,
                [ 0.92629528,  0.28055015,  1.70484698]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     if seed is None:
@@ -963,7 +975,7 @@ def var(x, axis=None, keepdims=False):
     '''
     axis = _normalize_axis(axis, ndim(x))
     if x.dtype.base_dtype == tf.bool:
-        x = tf.cast(x, _FLOATX)
+        x = tf.cast(x, floatx())
     m = tf.reduce_mean(x, reduction_indices=axis, keep_dims=True)
     devs_squared = tf.square(x - m)
     return tf.reduce_mean(devs_squared,
@@ -982,7 +994,7 @@ def mean(x, axis=None, keepdims=False):
     '''
     axis = _normalize_axis(axis, ndim(x))
     if x.dtype.base_dtype == tf.bool:
-        x = tf.cast(x, _FLOATX)
+        x = tf.cast(x, floatx())
     return tf.reduce_mean(x, reduction_indices=axis, keep_dims=keepdims)
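The `var` and `mean` hunks cast boolean tensors before reducing, because TensorFlow's mean has no bool kernel; after the cast, the mean of a boolean tensor is the fraction of `True` entries. The cast now targets the live `floatx()` value instead of the import-time `_FLOATX`. The idea in plain NumPy:

```python
import numpy as np

x = np.array([True, False, True, True])
# A mean over booleans is only meaningful after casting to a float dtype;
# the result is the fraction of True entries.
print(x.astype('float32').mean())  # 0.75
```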


@@ -2073,7 +2085,7 @@ def _preprocess_deconv_output_shape(shape, dim_ordering):
 
 
 def _preprocess_conv2d_input(x, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(x) == 'float64':
         x = tf.cast(x, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2085,7 +2097,7 @@ def _preprocess_conv2d_input(x, dim_ordering):
 
 
 def _preprocess_conv3d_input(x, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(x) == 'float64':
         x = tf.cast(x, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2097,7 +2109,7 @@ def _preprocess_conv3d_input(x, dim_ordering):
 
 
 def _preprocess_conv2d_kernel(kernel, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(kernel) == 'float64':
         kernel = tf.cast(kernel, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2109,7 +2121,7 @@ def _preprocess_conv2d_kernel(kernel, dim_ordering):
 
 
 def _preprocess_conv3d_kernel(kernel, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(kernel) == 'float64':
         kernel = tf.cast(kernel, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2134,7 +2146,7 @@ def _postprocess_conv2d_output(x, dim_ordering):
     if dim_ordering == 'th':
         x = tf.transpose(x, (0, 3, 1, 2))
 
-    if _FLOATX == 'float64':
+    if floatx() == 'float64':
         x = tf.cast(x, 'float64')
     return x

@@ -2143,7 +2155,7 @@ def _postprocess_conv3d_output(x, dim_ordering):
     if dim_ordering == 'th':
         x = tf.transpose(x, (0, 4, 1, 2, 3))
 
-    if _FLOATX == 'float64':
+    if floatx() == 'float64':
         x = tf.cast(x, 'float64')
     return x

@@ -2158,13 +2170,14 @@ def conv1d(x, kernel, stride=1, border_mode='valid',
         border_mode: string, "same" or "valid".
     '''
     # pre-process dtype
-    if _FLOATX == 'float64':
+    x_dtype = dtype(x)
+    if x_dtype == 'float64':
         x = tf.cast(x, 'float32')
         kernel = tf.cast(kernel, 'float32')
     padding = _preprocess_border_mode(border_mode)
     x = tf.nn.conv1d(x, kernel, stride, padding=padding)
     # post-process dtype
-    if _FLOATX == 'float64':
+    if x_dtype == 'float64':
         x = tf.cast(x, 'float64')
     return x
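TensorFlow's convolution kernels at the time ran in float32 only, so float64 data had to be cast down, convolved, and cast back up. These hunks key that round trip on the tensor's own dtype (`dtype(x)`) rather than the global `_FLOATX`, so a float32 graph is no longer pointlessly cast just because floatx is set to float64. The pattern, as a hedged standalone sketch (`op` stands in for any float32-only operation; `run_in_float32` is a hypothetical helper, not part of the backend):

```python
import tensorflow as tf


def run_in_float32(op, x):
    """Apply a float32-only op to x, preserving a float64 input's dtype."""
    # Decide based on the tensor itself, not on a global setting.
    needs_round_trip = x.dtype.base_dtype == tf.float64
    if needs_round_trip:
        x = tf.cast(x, 'float32')
    y = op(x)
    # Restore the caller's precision on the way out.
    return tf.cast(y, 'float64') if needs_round_trip else y
```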

@@ -2367,21 +2380,27 @@ def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
 
 # RANDOMNESS
 
-def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
+def random_normal(shape, mean=0.0, std=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(10e6)
     return tf.random_normal(shape, mean=mean, stddev=std,
                             dtype=dtype, seed=seed)
 
 
-def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
+def random_uniform(shape, low=0.0, high=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(10e6)
     return tf.random_uniform(shape, minval=low, maxval=high,
                              dtype=dtype, seed=seed)
 
 
-def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
+def random_binomial(shape, p=0.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(10e6)
     return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
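`random_binomial` (its return statement is truncated above) thresholds uniform noise: a uniform draw on [0, 1) satisfies `u <= p` with probability p, so the comparison yields Bernoulli(p) ones and zeros. The same construction in NumPy, as an illustrative sketch:

```python
import numpy as np


def random_binomial(shape, p=0.0, dtype='float32', seed=None):
    # u <= p holds with probability p, making this a Bernoulli(p) draw.
    rng = np.random.RandomState(seed)
    return (rng.uniform(size=shape) <= p).astype(dtype)


print(random_binomial((2, 3), p=0.5, seed=42))
```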
keras/backend/theano_backend.py (40 changes: 28 additions & 12 deletions)
@@ -14,7 +14,7 @@
 from theano.sandbox.softsign import softsign as T_softsign
 import inspect
 import numpy as np
-from .common import _FLOATX, _EPSILON, image_dim_ordering
+from .common import _FLOATX, floatx, _EPSILON, image_dim_ordering
 py_all = all
 
 
@@ -56,9 +56,11 @@ def to_dense(tensor):
     return tensor
 
 
-def variable(value, dtype=_FLOATX, name=None):
+def variable(value, dtype=None, name=None):
     '''Instantiates a variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     if hasattr(value, 'tocoo'):
         _assert_sparse_module()
         return th_sparse_module.as_sparse_variable(value)
@@ -67,9 +69,11 @@ def variable(value, dtype=_FLOATX, name=None):
     return theano.shared(value=value, name=name, strict=False)
 
 
-def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
+def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
     '''Instantiate an input data placeholder variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     if shape is None and ndim is None:
         raise ValueError('Specify either a shape or ndim value.')
     if shape is not None:
@@ -111,21 +115,27 @@ def eval(x):
     return to_dense(x).eval()
 
 
-def zeros(shape, dtype=_FLOATX, name=None):
+def zeros(shape, dtype=None, name=None):
     '''Instantiates an all-zeros variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     return variable(np.zeros(shape), dtype, name)
 
 
-def ones(shape, dtype=_FLOATX, name=None):
+def ones(shape, dtype=None, name=None):
     '''Instantiates an all-ones variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     return variable(np.ones(shape), dtype, name)
 
 
-def eye(size, dtype=_FLOATX, name=None):
+def eye(size, dtype=None, name=None):
     '''Instantiates an identity matrix.
     '''
+    if dtype is None:
+        dtype = floatx()
     return variable(np.eye(size), dtype, name)


@@ -137,12 +147,12 @@ def zeros_like(x, name=None):
     return T.zeros_like(x)
 
 
-def random_uniform_variable(shape, low, high, dtype=_FLOATX, name=None):
+def random_uniform_variable(shape, low, high, dtype=None, name=None):
     return variable(np.random.uniform(low=low, high=high, size=shape),
                     dtype=dtype, name=name)
 
 
-def random_normal_variable(shape, mean, scale, dtype=_FLOATX, name=None):
+def random_normal_variable(shape, mean, scale, dtype=None, name=None):
     return variable(np.random.normal(loc=0.0, scale=scale, size=shape),
                     dtype=dtype, name=name)

@@ -284,7 +294,7 @@ def mean(x, axis=None, keepdims=False):
     dtype = None
     # bool is available since theano v0.9dev
     if 'int' in x.dtype or x.dtype == 'bool':
-        dtype = _FLOATX
+        dtype = floatx()
     return T.mean(x, axis=axis, keepdims=keepdims, dtype=dtype)
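Theano's `mean` upcasts integer and boolean inputs for a related reason: accumulating a mean in an integer dtype truncates the fractional part. The accumulator dtype now tracks `floatx()` at call time as well. A NumPy illustration of why the upcast matters:

```python
import numpy as np

x = np.array([0, 1, 1], dtype='int32')
print(x.mean(dtype='int32'))    # 0 -- integer accumulation truncates 2/3
print(x.mean(dtype='float32'))  # 0.6666667 -- upcasting preserves it
```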


@@ -1799,21 +1809,27 @@ def _old_theano_pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
 # RANDOMNESS
 
 
-def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
+def random_normal(shape, mean=0.0, std=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)
     return rng.normal(size=shape, avg=mean, std=std, dtype=dtype)
 
 
-def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
+def random_uniform(shape, low=0.0, high=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)
     return rng.uniform(shape, low=low, high=high, dtype=dtype)
 
 
-def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
+def random_binomial(shape, p=0.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)