Commit c48c6d6

SINGA-386 Implement RNN operation for autograd
- Redesign some RNN-related functions and their APIs.
- The RNN operation is now designed for mini-batch training.
- The related files can be built without error.
1 parent 95b4377 commit c48c6d6
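Note: based on the shape comments in the diff below, the redesigned operation consumes a whole mini-batch at once — `inputs` of shape (seq_length, batch_size, input_size) and h0/c0 of shape (num_layers * num_directions, batch_size, hidden_size). The following is a hypothetical usage sketch of the intended API, not code from this commit; the `_3dTo1d`/`_1dTo3d` helpers are still stubs here, so this cannot run yet, and it assumes a CUDA-enabled SINGA build.

    import numpy as np
    from singa import autograd, device, tensor

    dev = device.create_cuda_gpu()  # the operation is cuDNN-only
    seq_length, batch_size, input_size, hidden_size = 5, 2, 4, 8

    # Mini-batch input: (seq_length, batch_size, input_size).
    x = tensor.Tensor((seq_length, batch_size, input_size), dev)
    x.copy_from_numpy(
        np.random.rand(seq_length, batch_size, input_size).astype(np.float32))

    # Initial hidden state: (num_layers * num_directions, batch_size, hidden_size).
    h0 = tensor.Tensor((1, batch_size, hidden_size), dev)
    h0.set_value(0.0)

    rnn = autograd.RNN(input_size, hidden_size)  # class under construction in this commit
    outputs, hout = rnn(x, h0)  # non-LSTM modes return (outputs, hout)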

4 files changed (+123, -99 lines)

python/singa/autograd.py (+15, -9)
@@ -966,8 +966,8 @@ def forward(self, X, h0, c0, W):
         # hout_cout: (hout, cout) if lstm, else (hout,)
         # hout, cout of shape (num_layers * num_directions, batch,
         # hidden_size)
-        oututs= 1dTo3d(Y)
-
+        oututs= _1dTo3d(Y)
+
         if self.rnn_mode != 'lstm':
             return outputs, hout
         else:
@@ -977,7 +977,7 @@ def backward(self, dY, dh, dc=CTensor([])):
         assert training is True and hasattr(
             self, 'cache'), 'Please set training as True before do BP. '

-        dY_1d= 3dTo1d(dY)
+        dY_1d= _3dTo1d(dY)

         if dY_1d.device().id() != self.handle.device_id:
             dY_1d.ToDevice(self.cache[0].device())
@@ -988,7 +988,7 @@ def backward(self, dY, dh, dc=CTensor([])):
         dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
             self.handle, dY_1d, dh, dc, self.cache)

-        dX = 1dTo3d(dX_1d)
+        dX = _1dTo3d(dX_1d)

         if self.rnn_mode != 'lstm':
             return dX, dhout, dW
@@ -1038,7 +1038,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first
             W_Size *= mult * w_size

         self.W_Size = W_Size
-        self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)
+        self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)  # TODO: assign value of Wi separately
         self.W.uniform(0.0, 1.0)

     def __call__(self, inputs, h0, c0=None):
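Note: the packed parameter size W_Size mirrors the per-layer loop that this commit removes from rnn.cc (see the third file below): a per-mode multiplier — 1 for relu/tanh, 4 for lstm, 3 for gru, doubled when bidirectional — times hidden_size * (in_dim + hidden_size + 2) per stack. A standalone sketch of that formula, not the committed code:

    def rnn_weight_size(input_size, hidden_size, num_stacks,
                        rnn_mode='lstm', bidirectional=False):
        # Number of parameter sets (gates) per layer, per direction.
        mult = {'relu': 1, 'tanh': 1, 'lstm': 4, 'gru': 3}[rnn_mode]
        if bidirectional:
            mult *= 2
        weight_size = 0
        for i in range(num_stacks):
            # Layer 0 reads the input; deeper layers read hidden states.
            # "+ 2" covers the two bias vectors per parameter set.
            in_dim = input_size if i == 0 else hidden_size
            weight_size += mult * hidden_size * (in_dim + hidden_size + 2)
        return weight_size

    assert rnn_weight_size(4, 8, 1, 'tanh') == 1 * 8 * (4 + 8 + 2)  # 112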
@@ -1052,17 +1052,23 @@ def __call__(self, inputs, h0, c0=None):
             assert c0 is not None, 'Please input c0.'
             self.device_check(h0, c0)

-        self.handle = signa.CudnnRNNHandle(inputs.data, *SOME_PARAMETERS*)
+        if not hasattr(self, 'handle'):
+            self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
+                                               self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
+        elif inputs.shape[0] != self.handle.seq_length_ or inputs.shape[1] != self.handle.batch_size_:
+            self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
+                                               self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
+
         self.handle.device_id = inputs.device.id()

-        X= 3dTo1d(inputs)
+        X= _3dTo1d(inputs)
         outputs = rnn(self.handle, X, h0, c0, self.W)
         return outputs

-    def 3dTo1d(self, inputs):
+    def _3dTo1d(self, inputs):
         pass

-    def 1dTo3d(self, *args):
+    def _1dTo3d(self, *args):
         pass

 class LSTM(RNN):
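Note: `_3dTo1d` and `_1dTo3d` remain `pass` stubs in this commit. Judging from the shape comments above, their contract is to flatten a (seq_length, batch_size, feature) tensor into the contiguous 1-D buffer the cuDNN calls expect, and to restore the 3-D view afterwards. A hypothetical numpy sketch of that contract (names and signatures are illustrative, not the SINGA implementation):

    import numpy as np

    def _3d_to_1d(x):
        # (seq_length, batch_size, feature) -> contiguous 1-D buffer.
        return np.ascontiguousarray(x).reshape(-1)

    def _1d_to_3d(y, seq_length, batch_size, feature):
        # Inverse view: recover (seq_length, batch_size, feature).
        return y.reshape(seq_length, batch_size, feature)

    x = np.arange(24, dtype=np.float32).reshape(3, 2, 4)
    assert (_1d_to_3d(_3d_to_1d(x), 3, 2, 4) == x).all()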

src/api/model_operation.i (+27, -1)
@@ -7,7 +7,7 @@
 #include "../src/model/operation/convolution.h"
 #include "../src/model/operation/batchnorm.h"
 #include "../src/model/operation/pooling.h"
-
+#include "../src/model/operation/rnn.h"
 %}

 namespace singa {
@@ -51,6 +51,14 @@ class PoolingHandle {
   int pooled_width;
 };

+class RNNHandle {
+ public:
+  RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+            const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);
+
+  size_t batch_size_;
+  size_t seq_length_;
+};

 #if USE_CUDNN
 class CudnnConvHandle: public ConvHandle {
@@ -106,6 +114,24 @@ Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x);

 Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy, const Tensor& x, const Tensor& y);

+
+class CudnnRNNHandle: public RNNHandle {
+ public:
+  CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                 const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);
+
+  size_t batch_size_;
+  size_t seq_length_;
+
+};
+
+std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
+
+std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
+
+std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache);
+
+
 #endif  // USE_CUDNN

 } //namespace singa
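Note: SWIG exposes these signatures to Python under the singa module. A hedged sketch of the call pattern, inferred from the autograd diff above — the unpacked names are illustrative, the handle/tensor variables must already exist, and a CUDA-enabled build is assumed:

    # Training forward: returns the flattened output sequence plus the
    # final hidden/cell states, i.e. [y_1d, hy, cy].
    y_1d, hy, cy = singa.GpuRNNForwardTraining(handle, x_1d, h0, c0, W)

    # Backward: takes dY plus the cached forward tensors and returns
    # [dx_1d, dhx, dcx, dW], matching the unpacking in autograd.py.
    dx_1d, dhx, dcx, dW = singa.GpuRNNBackward(handle, dy_1d, dhy, dcy, cache)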

src/model/operation/rnn.cc (+69, -78)
@@ -2,8 +2,12 @@

 namespace singa {

-RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                     const std::string Rnn_mode, const float Dropout, const bool bidirectional) {
+RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                     const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size) {
+
+  CHECK_EQ(input.shape(2), Input_size);
+  batch_size_ = input.shape(1);
+  seq_length_= input.shape(0);

   input_size_ = Input_size;
   CHECK_GT(input_size_, 0u);
@@ -28,68 +32,62 @@ RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const si
   }
   // the first constant (4) is the size of float
   // the second constant (2, 8, 6) is the number of sets of params
-  int mult = 1;
-  if (rnn_mode_ == "relu" || rnn_mode_ == "tanh")
-    mult *= 1;
-  else if (rnn_mode_ == "lstm")
-    mult *= 4;
-  else if (rnn_mode_ == "gru")
-    mult *= 3;
-  if (bidirectional)
-    mult *= 2;
-
-  weight_size = 0;
-  for (size_t i = 0; i < num_stacks_; i++) {
-    size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2);
-    if (i > 0)
-      dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2);
-    weight_size += mult * dim;
-  }
+  weight_size= Weight_size;
+
 };

 #ifdef USE_CUDNN

-CudnnRNNHandle::CudnnRNNHandle(const vector<Tensor> &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                               const std::string Rnn_mode, const float Dropout, const bool bidirectional):
-  RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) {
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                               const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size):
+  RNNHandle(input, Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional, Weight_size) {

-  CHECK_GT(inputs.size(), 1u + has_cell_);
-  size_t num_x = inputs.size() - has_cell_ - 1;
-
-  DataType dtype = inputs.at(0).data_type();
-  if (rnn_desc_ != nullptr)
-    CHECK_EQ(dtype_, GetCudnnDataType(dtype))
-        << "Cannot change cudnn data type during training from " << dtype_
-        << " to " << GetCudnnDataType(dtype);
-  else
-    dtype_ = GetCudnnDataType(dtype);
+  DataType dtype = input.data_type();
+  dtype_ = GetCudnnDataType(dtype);

-  UpdateStates(num_x, inputs);
+  UpdateIODescriptors(input);
+  ResetHiddenAndCellDescriptors();
+  SetRNNDescriptor(input.device());
+  UpdateSpaces(seq_length_, input.device());
 };

-void CudnnRNNHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
-  UpdateIODescriptors(num_x, inputs);
-  size_t new_batch_size = inputs.at(0).shape(0);
-  if (batch_size_ != new_batch_size)
-    ResetHiddenAndCellDescriptors(new_batch_size);
-  if (rnn_desc_ == nullptr)
-    SetRNNDescriptor(inputs.at(0).device());
-  UpdateSpaces(num_x, inputs.at(0).device());
-  batch_size_ = new_batch_size;
-  seq_length_ = num_x;
+CudnnRNNHandle::~CudnnRNNHandle() {
+  if (weight_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
+  if (dropout_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
+  if (rnn_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
+  if (hx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
+  if (hy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_));
+  if (cx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
+  if (cy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
+  if (dhx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_));
+  if (dhy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_));
+  if (dcx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_));
+  if (dcy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_));
+  DestroyIODescriptors();
 };

 void CudnnRNNHandle::DestroyIODescriptors() {
   if (x_descs_ != nullptr) {
-    for (size_t i = 0; i < max_length_; i++) {
+    for (size_t i = 0; i < seq_length_; i++) {
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
     }
     delete [] x_descs_;
     delete [] dx_descs_;
   }
   if (y_descs_ != nullptr) {
-    for (size_t i = 0; i < max_length_; i++) {
+    for (size_t i = 0; i < seq_length_; i++) {
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
     }
@@ -98,61 +96,60 @@ void CudnnRNNHandle::DestroyIODescriptors() {
   }
 };

-void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
-  bool reset = false;
-  if (max_length_ < len) {
-    DestroyIODescriptors();
-    max_length_ = len;
-    x_descs_ = new cudnnTensorDescriptor_t[len];
-    dx_descs_ = new cudnnTensorDescriptor_t[len];
-    y_descs_ = new cudnnTensorDescriptor_t[len];
-    dy_descs_ = new cudnnTensorDescriptor_t[len];
-    for (size_t i = 0; i < len; i++) {
+
+void CudnnRNNHandle::UpdateIODescriptors(const Tensor &input) {
+  x_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  dx_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  y_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  dy_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  for (size_t i = 0; i < seq_length_; i++) {
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i]));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i]));
   }
-    reset = true;
-  }

-  for (size_t i = 0; i < len; i++) {
-    CHECK_EQ(inputs[i].shape(1), input_size_);
-    if (inputs[i].shape(0) != batch_size_ || reset) {
+  for (size_t i = 0; i < seq_length_; i++) {
+    CHECK_EQ(input.shape(2), input_size_);
     int d[3] = {1, 1, 1}, s[3] = {1, 1, 1};
-    d[0] = static_cast<int>(inputs[i].shape(0));
+    d[0] = static_cast<int>(batch_size_);
     CHECK_GT(d[0], 0);
-    d[1] = static_cast<int>(inputs[i].shape(1));
+    d[1] = static_cast<int>(input_size_);
     s[0] = d[1] * d[2];
     s[1] = d[2];
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s));
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s));

-    d[0] = static_cast<int>(inputs[i].shape(0));
+    d[0] = static_cast<int>(batch_size_);
     d[1] = static_cast<int>(hidden_size_ * num_directions_);
     s[0] = d[1] * d[2];
     s[1] = d[2];
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s));
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s));
   }
-  }
 };

-void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
-  if (batch_size_ == 0) {
+void CudnnRNNHandle::ResetHiddenAndCellDescriptors() {
+  if (cx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
+  if (dcx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_));
+  if (cy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
+  if (dcy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_));
+  if (hx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
+  if (dhx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_));
+  if (hy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
+  if (dhy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_));
-  }

   int dim[3] = {1, 1, 1};
   dim[0] = static_cast<int>(num_stacks_ * num_directions_);
-  dim[1] = static_cast<int>(batch_size);
+  dim[1] = static_cast<int>(batch_size_);
   dim[2] = static_cast<int>(hidden_size_);
   int stride[3] = {1, 1, 1};
   stride[0] = dim[1] * dim[2];
@@ -229,7 +226,7 @@ void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
     reserve_space_ = Tensor(Shape{count}, dev, kChar);
     // reserve_space_.SetValue(0);
   }
-}
+};

 Tensor MergeInputs(size_t num, const vector<Tensor> &in) {
   if (num == 1)
@@ -265,15 +262,14 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,

 std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
   DataType dtype = input.data_type();
-  auto dev = input.at(0).device();
+  auto dev = input.device();


   Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
   Tensor output(outshape, dev, dtype);
   // LOG(INFO) << "output size " << output.Size();

   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);

   Tensor cy;
@@ -339,7 +335,6 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
   // LOG(INFO) << "output size " << output.Size();

   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);

   Tensor cy;
@@ -389,7 +384,7 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
   return {output, hy, cy};
 };

-std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
+std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache) {
   const Tensor x = cache[0];
   const Tensor y = cache[1];
   const Tensor hx = cache[2];
@@ -399,26 +394,22 @@ std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tenso
   auto dev = y.device();
   auto dtype = y.data_type();

-
   CHECK_EQ(dY.Size(), y.Size());

-
   Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
-  CHECK_EQ(x.shape(), xshape)
   Tensor dx(xshape, dev, dtype);

   Tensor dw(W.shape(), dev, dtype);

   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape)
   Tensor dhx(state_shape, dev, dtype);

   Tensor dcx;
   if (crh.has_cell_)
     dcx.ResetLike(dhx);

   dw.SetValue(0.0f);
-  Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
+  Block *yb = y.block(), *dyb = dY.block(), *dhyb = dhy.block(),
         *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
         *wb = W.block(), *dwb = dw.block(), *hxb = hx.block(),
         *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
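Note: the outshape and xshape expressions above are plain arithmetic over flattened buffers. A small self-contained check of the two relations (illustrative numbers):

    # outshape: input.Size() * hidden_size / input_size * num_directions
    seq_length, batch_size = 5, 2
    input_size, hidden_size, num_directions = 4, 8, 1

    x_size = seq_length * batch_size * input_size
    y_size = x_size * hidden_size // input_size * num_directions
    assert y_size == seq_length * batch_size * hidden_size * num_directions

    # xshape in GpuRNNBackward inverts the same relation:
    # y.Size() * input_size / hidden_size / num_directions
    assert y_size * input_size // hidden_size // num_directions == x_size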
