SINGA-386 Implement RNN operation for autograd
- Fix bugs in the C++ parts; the code now builds without errors.
xuewanqi committed Jul 17, 2018
1 parent 33ddc2d commit c6957b7
Showing 3 changed files with 266 additions and 126 deletions.
python/singa/autograd.py: 110 changes (67 additions, 43 deletions)
@@ -347,19 +347,6 @@ def add_bias(x, b, axis=0):
return AddBias(axis)(x, b)[0]


-class Add(Operation):
-
-    def forward(self, a, b):
-        return singa.__add__(a, b)
-
-    def backward(self, dy):
-        return dy, dy
-
-
-def add(a, b):
-    return Add()(a, b)[0]


class SoftMax(Operation):
'''
Apply SoftMax for each row of the Tensor or each column of the Tensor
@@ -469,24 +456,22 @@ def cross_entropy(y, t):

class SoftMaxCrossEntropy(Operation):

-    def __init__(self, t):
-        self.t = t.data
-
-    def forward(self, x):
+    def forward(self, x, t):
self.p = singa.SoftMax(x)
+        self.t = t
loss = CTensor((1,), self.p.device())
-        ret = singa.CrossEntropyFwd(self.p, self.t)
+        ret = singa.CrossEntropyFwd(self.p, t)
loss.SetFloatValue(singa.SumAsFloat(ret) / x.shape()[0])
return loss

def backward(self, dy=1.0):
dx = singa.SoftmaxCrossEntropyBwd(self.p, self.t)
-        return singa.DivFloat(dx, float(self.p.shape()[0]))
+        return singa.DivFloat(dx, float(self.p.shape()[0])), None


def softmax_cross_entropy(x, t):
# x is the logits and t is the ground truth; both are 2D.
-    return SoftMaxCrossEntropy(t)(x)[0]
+    return SoftMaxCrossEntropy()(x, t)[0]
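
With this change the label tensor is passed to forward() instead of the constructor, so softmax_cross_entropy now takes both tensors, and backward() returns an extra None for the label input, which carries no gradient. A minimal usage sketch, assuming a working singa build; the shapes and values below are illustrative only:

    import numpy as np
    from singa import autograd, tensor

    autograd.training = True

    # 4 samples, 5 classes: 2D logits and one-hot targets.
    x = tensor.Tensor(shape=(4, 5))
    x.gaussian(0.0, 1.0)
    t = tensor.from_numpy(np.eye(5, dtype=np.float32)[[0, 2, 1, 4]])

    # The target is a forward() argument now, not a constructor argument.
    loss = autograd.softmax_cross_entropy(x, t)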


def ctensor2numpy(x):
@@ -587,12 +572,12 @@ def backward(self, dy):
return tuple(dxs)


-def cat(xs, axis=0):
+def concat(xs, axis=0):
# xs is a tuple of multiple Tensors
return Concat(axis)(*xs)[0]
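
The rename from cat to concat only touches the Python-side helper; the call pattern is unchanged. An illustrative sketch, assuming a working singa build (shapes made up):

    from singa import autograd, tensor

    a = tensor.Tensor(shape=(2, 3))
    b = tensor.Tensor(shape=(2, 3))
    a.set_value(1.0)
    b.set_value(2.0)

    c = autograd.concat((a, b), axis=0)  # shape (4, 3)
    d = autograd.concat((a, b), axis=1)  # shape (2, 6)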


-class _Conv2d(Operation):
+class _Conv2D(Operation):

def __init__(self, handle):
self.handle = handle
@@ -642,10 +627,10 @@ def backward(self, dy):


def conv2d(handle, x, W, b):
-    return _Conv2d(handle)(x, W, b)[0]
+    return _Conv2D(handle)(x, W, b)[0]


-class Conv2d(Layer):
+class Conv2D(Layer):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, **kwargs):
Expand Down Expand Up @@ -708,6 +693,10 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,

def __call__(self, x):
assert x.shape[1] == self.in_channels, 'in_channels dismatched'
+        assert (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]
+                ) % self.stride[0] == 0, 'invalid padding or strides.'
+        assert (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]
+                ) % self.stride[1] == 0, 'invalid padding or stride.'

self.device_check(x, self.W, self.b)

@@ -731,7 +720,7 @@ def __call__(self, x):
return y
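
The new assertions in Conv2D.__call__ require that the padded input extent minus the kernel size be an exact multiple of the stride, so the output size (H + 2*pad - k) / stride + 1 is an integer with no truncation. A small self-contained check that mirrors that rule (numbers are made up):

    # Mirrors the assertion above: the padded extent minus the kernel
    # must divide evenly by the stride.
    def fits(size, kernel, padding, stride):
        return (size + 2 * padding - kernel) % stride == 0

    assert fits(32, 3, 1, 1)        # 3x3 kernel, pad 1, stride 1: accepted
    assert fits(32, 2, 0, 2)        # (32 - 2) % 2 == 0: accepted
    assert not fits(32, 4, 0, 3)    # (32 - 4) % 3 == 1: rejected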


-class BatchNorm2d(Layer):
+class BatchNorm(Layer):

def __init__(self, num_features, momentum=0.9):
self.channels = num_features
@@ -771,12 +760,12 @@ def __call__(self, x):
self.momentum, x.data)
self.handle.device_id = x.device.id()

-        y = batchnorm_2d(self.handle, x, self.scale, self.bias,
+        y = batchnorm(self.handle, x, self.scale, self.bias,
self.running_mean, self.running_var)
return y


-class _BatchNorm2d(Operation):
+class _BatchNorm(Operation):

def __init__(self, handle, running_mean, running_var):
self.running_mean = running_mean.data
@@ -796,7 +785,7 @@ def forward(self, x, scale, bias):
if self.handle.device_id == -1:
raise NotImplementedError
else:
-            y = singa.GpuBatchNormForwardInference(
+            y, _, _ = singa.GpuBatchNormForwardInference(
self.handle, x, scale, bias, self.running_mean, self.running_var)
return y

@@ -816,11 +805,11 @@ def backward(self, dy):
return dx, ds, db


-def batchnorm_2d(handle, x, scale, bias, running_mean, running_var):
-    return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0]
+def batchnorm(handle, x, scale, bias, running_mean, running_var):
+    return _BatchNorm(handle, running_mean, running_var)(x, scale, bias)[0]
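
BatchNorm2d becomes BatchNorm, and the inference path now unpacks the tuple returned by GpuBatchNormForwardInference, keeping only the output. Only the cuDNN path is implemented (the CPU branch raises NotImplementedError), so a usage sketch needs a CUDA device; the names and shapes below are illustrative, not taken from this commit:

    from singa import autograd, tensor, device

    dev = device.create_cuda_gpu()     # cuDNN is required for this layer
    autograd.training = True

    x = tensor.Tensor(shape=(8, 3, 32, 32), device=dev)  # NCHW input
    x.gaussian(0.0, 1.0)

    bn = autograd.BatchNorm(3)   # renamed from BatchNorm2d in this commit
    y = bn(x)                    # same shape as the input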


-class _Pooling2d(Operation):
+class _Pooling2D(Operation):

def __init__(self, handle):
self.handle = handle
@@ -846,10 +835,10 @@ def backward(self, dy):


def pooling_2d(handle, x):
-    return _Pooling2d(handle)(x)[0]
+    return _Pooling2D(handle)(x)[0]


-class Pooling2d(Layer):
+class Pooling2D(Layer):

def __init__(self, kernel_size, stride=None, padding=0, is_max=True):
if isinstance(kernel_size, int):
@@ -896,43 +885,78 @@ def __call__(self, x):
else:
if not hasattr(self, 'handle'):
self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
-                                                   self.padding, self.is_max)
+                                                   self.padding, self.is_max)  # False for nan_prop
elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
out_shape_w != self.handle.pooled_width:
self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
-                                                   self.padding, self.is_max)
+                                                   self.padding, self.is_max)  # False for nan_prop

self.handle.device_id = x.device.id()

y = pooling_2d(self.handle, x)
return y


-class MaxPool2d(Pooling2d):
+class MaxPooling2D(Pooling2D):

def __init__(self, kernel_size, stride=None, padding=0):
-        super(MaxPool2d, self).__init__(kernel_size, stride, padding, True)
+        super(MaxPooling2D, self).__init__(kernel_size, stride, padding, True)


-class AvgPool2d(Pooling2d):
+class AvgPooling2D(Pooling2D):

def __init__(self, kernel_size, stride=None, padding=0):
-        super(AvgPool2d, self).__init__(kernel_size, stride, padding, False)
+        super(AvgPooling2D, self).__init__(kernel_size, stride, padding, False)


-class MaxPool1d(Pooling2d):
+class MaxPooling1D(Pooling2D):

def __init__(self, kernel_size, stride=None, padding=0):
if stride is None:
stride = kernel_size
-        super(MaxPool2d, self).__init__(
+        super(MaxPooling2D, self).__init__(
(1, kernel_size), (0, stride), (0, padding), True)


-class AvgPool1d(Pooling2d):
+class AvgPooling1D(Pooling2D):

def __init__(self, kernel_size, stride=None, padding=0):
if stride is None:
stride = kernel_size
-        super(MaxPool2d, self).__init__(
+        super(MaxPooling2D, self).__init__(
(1, kernel_size), (0, stride), (0, padding), False)
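
The pooling layers are renamed (MaxPool2d, AvgPool2d, MaxPool1d, AvgPool1d become MaxPooling2D, AvgPooling2D, MaxPooling1D, AvgPooling1D), while Pooling2D.__call__ keeps deciding when to rebuild its handle from the usual pooling output size, out = (in + 2*pad - kernel) // stride + 1. A small stand-alone helper reproducing that arithmetic (illustrative only; the exact expression in __call__ is outside the shown hunk):

    # Output-size rule used when (re)creating the pooling handle.
    def pooled_size(in_size, kernel, padding, stride):
        return (in_size + 2 * padding - kernel) // stride + 1

    assert pooled_size(32, 2, 0, 2) == 16   # MaxPooling2D(2, stride=2) on 32x32 -> 16x16
    assert pooled_size(28, 3, 0, 3) == 9    # MaxPooling1D(3) pools only the last axis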


+class _RNN(Operation):
+    def __init__(self, handle):
+        self.handle = handle
+
+    def forward(self, X, W):
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            if training:
+                out, self.cache = singa.GpuRNNForwardTraining(self.handle, X, W)
+            else:
+                out = singa.GpuRNNForwardInference(self.handle, X, W)
+        return out
+
+    def backward(self, dY):
+        assert training is True and hasattr(
+            self, 'cache'), 'Please set training as True before doing BP.'
+
+        if dY.device().id() != self.handle.device_id:
+            dY.ToDevice(self.inputs[0].device())
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache)
+        return dX, dW
+
+
+def rnn():
+    pass
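
_RNN delegates the real work to the cuDNN wrappers (GpuRNNForwardTraining, GpuRNNForwardInference, GpuRNNBackward); the functional rnn() helper is still a stub in this commit. What the Python class adds is the contract around the module-level training flag: forward stores a cache only in training mode, and backward refuses to run without it. A toy, singa-free illustration of that contract:

    # Pure-Python illustration of the caching contract; not singa code.
    training = True

    class CachedOp:
        def forward(self, x):
            if training:
                self.cache = x          # stand-in for whatever the cuDNN call returns
            return 2 * x

        def backward(self, dy):
            assert training and hasattr(self, 'cache'), \
                'set training to True before running backward'
            return 2 * dy

    op = CachedOp()
    y = op.forward(3.0)     # training is True, so the cache is saved
    dx = op.backward(1.0)   # allowed because the cache exists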

+class RNN(Layer):

(remaining diff content not loaded)

