Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)

* fix

* fix

* fix

* fix

* Update parameter.py
piiswrong authored Nov 21, 2017
1 parent 12c3fb1 commit 1852e2f
Showing 6 changed files with 118 additions and 45 deletions.
18 changes: 17 additions & 1 deletion python/mxnet/gluon/block.py
@@ -286,6 +286,19 @@ def hybridize(self, active=True):
for cld in self._children:
cld.hybridize(active)

def cast(self, dtype):
"""Cast this Block to use another data type.
Parameters
----------
dtype : str or numpy.dtype
The new data type.
"""
for child in self._children:
child.cast(dtype)
for _, param in self.params.items():
param.cast(dtype)

def __call__(self, *args):
"""Calls forward. Only accepts positional arguments."""
return self.forward(*args)
@@ -388,7 +401,6 @@ def _build_cache(self, *args):

def _finish_deferred_init(self, hybrid, *args):
self.infer_shape(*args)
self.infer_type(*args)
if hybrid:
for is_arg, i in self._cached_op_args:
if not is_arg:
@@ -429,6 +441,10 @@ def hybridize(self, active=True):
self._active = active
super(HybridBlock, self).hybridize(active)

def cast(self, dtype):
self._clear_cached_op()
super(HybridBlock, self).cast(dtype)

def _infer_attrs(self, infer_fn, attr, *args):
"""Generic infer attributes."""
inputs, out = self._get_graph(*args)
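A minimal usage sketch of the new API, modeled on the updated test at the bottom of this diff (the two-layer network below is illustrative, not part of the commit): Block.cast walks every child block and every parameter, and HybridBlock.cast additionally clears the cached graph, so the next forward pass is re-traced with the new dtype.

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(16, activation='relu'))
    net.add(nn.Dense(4))
net.initialize()
net.hybridize()

net(mx.nd.ones((2, 8)))                    # first forward builds the float32 cached graph

net.cast('float64')                        # clears the cached op and casts every parameter
net(mx.nd.ones((2, 8), dtype='float64'))   # re-traced and executed in float64
mx.nd.waitall()
```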
24 changes: 15 additions & 9 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -22,6 +22,7 @@
'Dropout', 'BatchNorm', 'LeakyReLU', 'Embedding', 'Flatten',
'Lambda', 'HybridLambda']
import warnings
import numpy as np

from ..block import Block, HybridBlock
from ..utils import _indent
@@ -185,11 +186,11 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True,
self._units = units
self._in_units = in_units
self.weight = self.params.get('weight', shape=(units, in_units),
dtype=None, init=weight_initializer,
init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=(units,),
dtype=None, init=bias_initializer,
init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
@@ -336,24 +337,29 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
self.in_channels = in_channels

self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
shape=(in_channels,), dtype=None,
init=gamma_initializer, allow_deferred_init=True,
shape=(in_channels,), init=gamma_initializer,
allow_deferred_init=True,
differentiable=scale)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
shape=(in_channels,), dtype=None,
init=beta_initializer, allow_deferred_init=True,
shape=(in_channels,), init=beta_initializer,
allow_deferred_init=True,
differentiable=center)
self.running_mean = self.params.get('running_mean', grad_req='null',
shape=(in_channels,), dtype=None,
shape=(in_channels,),
init=running_mean_initializer,
allow_deferred_init=True,
differentiable=False)
self.running_var = self.params.get('running_var', grad_req='null',
shape=(in_channels,), dtype=None,
shape=(in_channels,),
init=running_variance_initializer,
allow_deferred_init=True,
differentiable=False)

def cast(self, dtype):
if np.dtype(dtype).name == 'float16':
dtype = 'float32'
super(BatchNorm, self).cast(dtype)

def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
return F.BatchNorm(x, gamma, beta, running_mean, running_var,
name='fwd', **self._kwargs)
@@ -437,7 +443,7 @@ def __init__(self, input_dim, output_dim, dtype='float32',
self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
'dtype': dtype}
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
dtype=None, init=weight_initializer,
init=weight_initializer,
allow_deferred_init=True)

def hybrid_forward(self, F, x, weight):
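The BatchNorm.cast override keeps batch-norm parameters and running statistics in single precision when the rest of the network is cast to half precision. A small sketch of the intended behaviour (the layers below are illustrative; checking parameter dtypes only needs a CPU context):

```python
import numpy as np
from mxnet.gluon import nn

bn = nn.BatchNorm(in_channels=4)
bn.initialize()
bn.cast('float16')                          # BatchNorm silently promotes this to float32
assert bn.gamma.data().dtype == np.float32
assert bn.running_mean.data().dtype == np.float32

dense = nn.Dense(3, in_units=5)
dense.initialize()
dense.cast('float16')                       # other layers do cast to float16
assert dense.weight.data().dtype == np.float16
```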
4 changes: 2 additions & 2 deletions python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
dshape[layout.find('C')] = in_channels
wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
self.weight = self.params.get('weight', shape=wshapes[1],
dtype=None, init=weight_initializer,
init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=wshapes[2],
dtype=None, init=bias_initializer,
init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
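With the dtype=None arguments gone, a parameter created with deferred shape inference keeps the default float32 dtype from construction instead of having its dtype inferred from the first input; switching precision is now an explicit cast. A short sketch of the consequence (the layer below is illustrative, not from the commit):

```python
import numpy as np
from mxnet.gluon import nn

conv = nn.Conv2D(channels=8, kernel_size=3)        # in_channels left for deferred init
assert np.dtype(conv.weight.dtype) == np.float32   # default dtype, no longer None

conv.cast('float64')                               # precision changes are explicit now
assert np.dtype(conv.weight.dtype) == np.float64
```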
81 changes: 62 additions & 19 deletions python/mxnet/gluon/parameter.py
@@ -107,8 +107,8 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
self._differentiable = differentiable
self._allow_deferred_init = allow_deferred_init
self._grad_req = None
self._shape = shape
self.name = name
self.shape = shape
self.dtype = dtype
self.lr_mult = lr_mult
self.wd_mult = wd_mult
@@ -138,6 +138,23 @@ def grad_req(self, req):
elif self._data is not None:
self._init_grad()

@property
def shape(self):
return self._shape

@shape.setter
def shape(self, new_shape):
if self._shape is None:
self._shape = new_shape
return

assert len(self._shape) == len(new_shape) and \
all(j == 0 or i == j for i, j in zip(new_shape, self._shape)), \
"Expected shape %s is incompatible with given shape %s."%(
str(new_shape), str(self._shape))

self._shape = new_shape

def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
if ctx is list:
@@ -147,9 +164,12 @@ def _check_and_get(self, arr_list, ctx):
return arr_list[0]
else:
ctx = context.current_context()
idx = self._ctx_map[ctx.device_typeid][ctx.device_id]
if idx is not None:
return arr_list[idx]
if ctx.device_typeid < len(self._ctx_map):
ctx_list = self._ctx_map[ctx.device_typeid]
if ctx.device_id < len(ctx_list):
idx = ctx_list[ctx.device_id]
if idx is not None:
return arr_list[idx]
raise RuntimeError(
"Parameter %s was not initialized on context %s. "
"It was only initialized on %s."%(
@@ -203,7 +223,7 @@ def _finish_deferred_init(self):
"""Finishes deferred initialization."""
if not self._deferred_init:
return
init, ctx, default_init = self._deferred_init
init, ctx, default_init, data = self._deferred_init
self._deferred_init = ()
assert self.shape is not None and np.prod(self.shape) > 0, \
"Cannot initialize Parameter %s because it has " \
@@ -212,10 +232,11 @@ def _finish_deferred_init(self):
self.name, str(self.shape))

with autograd.pause():
data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
ctx=context.cpu())
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init}), data)
if data is None:
data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
ctx=context.cpu())
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init}), data)

self._init_impl(data, ctx)

@@ -306,14 +327,14 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
ctx = [ctx]
if init is None:
init = default_init if self.init is None else self.init
if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
if not self.shape or np.prod(self.shape) <= 0:
if self._allow_deferred_init:
self._deferred_init = (init, ctx, default_init)
self._deferred_init = (init, ctx, default_init, None)
return
raise ValueError("Cannot initialize Parameter %s because it has " \
"invalid shape: %s."%(self.name, str(self.shape)))

self._deferred_init = (init, ctx, default_init)
self._deferred_init = (init, ctx, default_init, None)
self._finish_deferred_init()

def reset_ctx(self, ctx):
@@ -332,21 +353,25 @@ def reset_ctx(self, ctx):
with autograd.pause():
self._init_impl(data, ctx)
elif self._deferred_init:
init, _, default_init = self._deferred_init
self._deferred_init = (init, ctx, default_init)
init, _, default_init, data = self._deferred_init
self._deferred_init = (init, ctx, default_init, data)
else:
raise ValueError("Cannot reset context for Parameter %s because it "
"has not been initialized."%self.name)


def set_data(self, data):
"""Sets this parameter's value on all contexts to data."""
assert self._data is not None, \
"Parameter %s has not been initialized"%self.name
"""Sets this parameter's value on all contexts."""
self.shape = data.shape

if self._data is None:
assert self._deferred_init is not None, \
"Parameter %s has not been initialized"%self.name
self._deferred_init = self._deferred_init[:3] + (data,)
return

for arr in self.list_data():
arr[:] = data
if not self.shape or np.prod(self.shape) <= 0:
self.shape = data.shape

def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
@@ -415,6 +440,24 @@ def var(self):
init=self.init)
return self._var

def cast(self, dtype):
"""Cast data and gradient of this Parameter to a new data type.
Parameters
----------
dtype : str or numpy.dtype
The new data type.
"""
self.dtype = dtype
if self._data is None:
return
with autograd.pause():
self._data = [i.astype(dtype) for i in self._data]
if self._grad is None:
return
self._grad = [i.astype(dtype) for i in self._grad]
autograd.mark_variables(self._data, self._grad, self.grad_req)


class ParameterDict(object):
"""A dictionary managing a set of parameters.
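A minimal sketch of the Parameter-level changes, using standalone parameters (the names weight and bias below are illustrative): cast converts both data and gradient and re-marks them for autograd, while the new shape setter and set_data allow a value to be supplied before deferred initialization has completed.

```python
import numpy as np
import mxnet as mx

p = mx.gluon.Parameter('weight', shape=(2, 3))
p.initialize(ctx=mx.cpu())
p.cast('float16')                          # data and grad are cast, then re-marked
assert p.data().dtype == np.float16
assert p.grad().dtype == np.float16

q = mx.gluon.Parameter('bias', shape=(0,), allow_deferred_init=True)
q.initialize(ctx=mx.cpu())                 # shape unknown, so initialization is deferred
q.set_data(mx.nd.ones((4,)))               # stashed until deferred init finishes
assert q.shape == (4,)                     # the shape setter filled in the 0 dimension
```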
24 changes: 12 additions & 12 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ def __init__(self, hidden_size, activation='tanh',
self._activation = activation
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
dtype=None, init=i2h_weight_initializer,
init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
dtype=None, init=h2h_weight_initializer,
init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
dtype=None, init=i2h_bias_initializer,
init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
dtype=None, init=h2h_bias_initializer,
init=h2h_bias_initializer,
allow_deferred_init=True)

def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ def __init__(self, hidden_size,
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
dtype=None, init=i2h_weight_initializer,
init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
dtype=None, init=h2h_weight_initializer,
init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
dtype=None, init=i2h_bias_initializer,
init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
dtype=None, init=h2h_bias_initializer,
init=h2h_bias_initializer,
allow_deferred_init=True)

def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ def __init__(self, hidden_size,
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
dtype=None, init=i2h_weight_initializer,
init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
dtype=None, init=h2h_weight_initializer,
init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
dtype=None, init=i2h_bias_initializer,
init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
dtype=None, init=h2h_bias_initializer,
init=h2h_bias_initializer,
allow_deferred_init=True)

def state_info(self, batch_size=0):
12 changes: 10 additions & 2 deletions tests/python/unittest/test_gluon.py
@@ -570,12 +570,20 @@ def test_fill_shape_deferred():
def test_dtype():
net = mx.gluon.model_zoo.vision.resnet18_v1()
net.initialize()
net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
net.cast('float64')
with mx.autograd.record():
y = net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))
y.backward()

net = mx.gluon.model_zoo.vision.resnet18_v1()
net.initialize()
net.hybridize()
net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
net(mx.nd.ones((16, 3, 32, 32), dtype='float32'))

net.cast('float64')
net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))

mx.nd.waitall()


def test_fill_shape_load():
