namespace singa {
- RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                      const std::string Rnn_mode, const float Dropout, const bool bidirectional) {
+ RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                      const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size) {
+
+   CHECK_EQ(input.shape(2), Input_size);
+   batch_size_ = input.shape(1);
+   seq_length_ = input.shape(0);

  input_size_ = Input_size;
  CHECK_GT(input_size_, 0u);
@@ -28,68 +32,62 @@ RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const si
}

// the first constant (4) is the size of float
// the second constant (2, 8, 6) is the number of sets of params
-   int mult = 1;
-   if (rnn_mode_ == "relu" || rnn_mode_ == "tanh")
-     mult *= 1;
-   else if (rnn_mode_ == "lstm")
-     mult *= 4;
-   else if (rnn_mode_ == "gru")
-     mult *= 3;
-   if (bidirectional)
-     mult *= 2;
-
-   weight_size = 0;
-   for (size_t i = 0; i < num_stacks_; i++) {
-     size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2);
-     if (i > 0)
-       dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2);
-     weight_size += mult * dim;
-   }
+   weight_size = Weight_size;
+
};
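Since this change drops the in-constructor weight-size computation, the caller is now responsible for supplying Weight_size. Below is a minimal sketch of how that value could be derived, mirroring the formula the deleted block used; the helper name is illustrative and not part of this patch.

    // Sketch only: reproduces the removed per-stack parameter count so a caller
    // can precompute the Weight_size argument. mult is the number of parameter
    // sets per mode (1 for relu/tanh, 4 for lstm, 3 for gru), doubled if bidirectional.
    size_t ComputeWeightSize(const std::string &rnn_mode, bool bidirectional,
                             size_t input_size, size_t hidden_size, size_t num_stacks) {
      size_t mult = 1;
      if (rnn_mode == "lstm") mult = 4;
      else if (rnn_mode == "gru") mult = 3;
      if (bidirectional) mult *= 2;
      size_t weight_size = 0;
      for (size_t i = 0; i < num_stacks; i++) {
        // the first stack reads the input; deeper stacks read the previous hidden state
        size_t in_dim = (i == 0) ? input_size : hidden_size;
        weight_size += mult * hidden_size * (in_dim + hidden_size + 2);
      }
      return weight_size;
    }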

#ifdef USE_CUDNN

- CudnnRNNHandle::CudnnRNNHandle(const vector<Tensor> &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                                const std::string Rnn_mode, const float Dropout, const bool bidirectional):
-   RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) {
+ CudnnRNNHandle::CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                                const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size):
+   RNNHandle(input, Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional, Weight_size) {

-   CHECK_GT(inputs.size(), 1u + has_cell_);
-   size_t num_x = inputs.size() - has_cell_ - 1;
-
-   DataType dtype = inputs.at(0).data_type();
-   if (rnn_desc_ != nullptr)
-     CHECK_EQ(dtype_, GetCudnnDataType(dtype))
-         << "Cannot change cudnn data type during training from " << dtype_
-         << " to " << GetCudnnDataType(dtype);
-   else
-     dtype_ = GetCudnnDataType(dtype);
+   DataType dtype = input.data_type();
+   dtype_ = GetCudnnDataType(dtype);

-   UpdateStates(num_x, inputs);
+   UpdateIODescriptors(input);
+   ResetHiddenAndCellDescriptors();
+   SetRNNDescriptor(input.device());
+   UpdateSpaces(seq_length_, input.device());
};
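With this constructor the handle is built from a single 3-D input tensor laid out as (seq_length, batch_size, input_size), and all cuDNN descriptors are set up once here rather than lazily per batch. A hypothetical usage sketch, assuming a CUDA device and an LSTM; the sizes and the weight-size helper are illustrative, not part of this patch.

    // Sketch: Shape{20, 32, 128} means 20 time steps, batch 32, 128 input features.
    std::shared_ptr<singa::Device> cuda = std::make_shared<singa::CudaGPU>();
    singa::Tensor x(singa::Shape{20, 32, 128}, cuda);
    size_t weight_size = ComputeWeightSize("lstm", true, 128, 256, 2);  // hypothetical helper, see sketch above
    singa::CudnnRNNHandle handle(x, 128, 256, 2, "lstm", 0.5f, true, weight_size);
    // Constructor order: UpdateIODescriptors -> ResetHiddenAndCellDescriptors
    //                    -> SetRNNDescriptor -> UpdateSpaces.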

- void CudnnRNNHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
-   UpdateIODescriptors(num_x, inputs);
-   size_t new_batch_size = inputs.at(0).shape(0);
-   if (batch_size_ != new_batch_size)
-     ResetHiddenAndCellDescriptors(new_batch_size);
-   if (rnn_desc_ == nullptr)
-     SetRNNDescriptor(inputs.at(0).device());
-   UpdateSpaces(num_x, inputs.at(0).device());
-   batch_size_ = new_batch_size;
-   seq_length_ = num_x;
+ CudnnRNNHandle::~CudnnRNNHandle() {
+   if (weight_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
+   if (dropout_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
+   if (rnn_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
+   if (hx_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
+   if (hy_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_));
+   if (cx_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
+   if (cy_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
+   if (dhx_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_));
+   if (dhy_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_));
+   if (dcx_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_));
+   if (dcy_desc_ != nullptr)
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_));
+   DestroyIODescriptors();
};

void CudnnRNNHandle::DestroyIODescriptors() {
  if (x_descs_ != nullptr) {
-     for (size_t i = 0; i < max_length_; i++) {
+     for (size_t i = 0; i < seq_length_; i++) {
      CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
      CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
    }
    delete [] x_descs_;
    delete [] dx_descs_;
  }
  if (y_descs_ != nullptr) {
-     for (size_t i = 0; i < max_length_; i++) {
+     for (size_t i = 0; i < seq_length_; i++) {
      CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
      CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
    }
@@ -98,61 +96,60 @@ void CudnnRNNHandle::DestroyIODescriptors() {
  }
};

- void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
-   bool reset = false;
-   if (max_length_ < len) {
-     DestroyIODescriptors();
-     max_length_ = len;
-     x_descs_ = new cudnnTensorDescriptor_t[len];
-     dx_descs_ = new cudnnTensorDescriptor_t[len];
-     y_descs_ = new cudnnTensorDescriptor_t[len];
-     dy_descs_ = new cudnnTensorDescriptor_t[len];
-     for (size_t i = 0; i < len; i++) {
+
+ void CudnnRNNHandle::UpdateIODescriptors(const Tensor &input) {
+   x_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+   dx_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+   y_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+   dy_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+   for (size_t i = 0; i < seq_length_; i++) {
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i]));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i]));
  }
-     reset = true;
-   }

-   for (size_t i = 0; i < len; i++) {
-     CHECK_EQ(inputs[i].shape(1), input_size_);
-     if (inputs[i].shape(0) != batch_size_ || reset) {
+   for (size_t i = 0; i < seq_length_; i++) {
+     CHECK_EQ(input.shape(2), input_size_);
    int d[3] = {1, 1, 1}, s[3] = {1, 1, 1};
-     d[0] = static_cast<int>(inputs[i].shape(0));
+     d[0] = static_cast<int>(batch_size_);
    CHECK_GT(d[0], 0);
-     d[1] = static_cast<int>(inputs[i].shape(1));
+     d[1] = static_cast<int>(input_size_);
    s[0] = d[1] * d[2];
    s[1] = d[2];
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s));
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s));

-     d[0] = static_cast<int>(inputs[i].shape(0));
+     d[0] = static_cast<int>(batch_size_);
    d[1] = static_cast<int>(hidden_size_ * num_directions_);
    s[0] = d[1] * d[2];
    s[1] = d[2];
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s));
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s));
  }
-   }
};
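For clarity: each per-step descriptor above describes one time slice of the packed input, with dims (batch_size, input_size, 1) for x/dx and (batch_size, hidden_size * num_directions, 1) for y/dy, and fully packed strides. A stand-alone sketch of the same cuDNN call pattern, with assumed sizes:

    // Sketch: one packed 3-D time-step descriptor, as set in UpdateIODescriptors.
    cudnnTensorDescriptor_t xd;
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xd));
    int d[3] = {32, 128, 1};            // {batch_size, input_size, 1}
    int s[3] = {d[1] * d[2], d[2], 1};  // fully packed strides
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xd, CUDNN_DATA_FLOAT, 3, d, s));
    CUDNN_CHECK(cudnnDestroyTensorDescriptor(xd));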

- void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
-   if (batch_size_ == 0) {
+ void CudnnRNNHandle::ResetHiddenAndCellDescriptors() {
+   if (cx_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
+   if (dcx_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_));
+   if (cy_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
+   if (dcy_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_));
+   if (hx_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
+   if (dhx_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_));
+   if (hy_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
+   if (dhy_desc_ == nullptr)
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_));
-   }

  int dim[3] = {1, 1, 1};
  dim[0] = static_cast<int>(num_stacks_ * num_directions_);
-   dim[1] = static_cast<int>(batch_size);
+   dim[1] = static_cast<int>(batch_size_);
  dim[2] = static_cast<int>(hidden_size_);
  int stride[3] = {1, 1, 1};
  stride[0] = dim[1] * dim[2];
@@ -229,7 +226,7 @@ void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
      reserve_space_ = Tensor(Shape{count}, dev, kChar);
      // reserve_space_.SetValue(0);
    }
-   }
+   };

Tensor MergeInputs(size_t num, const vector<Tensor> &in) {
  if (num == 1)
@@ -265,15 +262,14 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,

std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
  DataType dtype = input.data_type();
-   auto dev = input.at(0).device();
+   auto dev = input.device();

  Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
  Tensor output(outshape, dev, dtype);
  // LOG(INFO) << "output size " << output.Size();

  Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-   CHECK_EQ(hx.shape(), state_shape);
  Tensor hy(state_shape, dev, dtype);

  Tensor cy;
@@ -339,7 +335,6 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
  // LOG(INFO) << "output size " << output.Size();

  Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-   CHECK_EQ(hx.shape(), state_shape);
  Tensor hy(state_shape, dev, dtype);

  Tensor cy;
@@ -389,7 +384,7 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
  return {output, hy, cy};
};
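One arithmetic note on the two forward functions: since input is packed as seq_length x batch_size x input_size, the expression input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_ is simply seq_length * batch_size * hidden_size * num_directions flattened into one dimension. A small worked example with assumed sizes:

    // Sketch: seq_length = 20, batch_size = 32, input_size = 128,
    // hidden_size = 256, bidirectional (num_directions = 2).
    size_t input_elems = 20 * 32 * 128;              // input.Size()
    size_t out_elems = input_elems * 256 / 128 * 2;  // = 20 * 32 * 256 * 2 = 327680
    // output can later be viewed as (seq_length, batch_size, hidden_size * num_directions)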

- std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
+ std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache) {
  const Tensor x = cache[0];
  const Tensor y = cache[1];
  const Tensor hx = cache[2];
@@ -399,26 +394,22 @@ std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tenso
  auto dev = y.device();
  auto dtype = y.data_type();

-
  CHECK_EQ(dY.Size(), y.Size());

-
  Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
-   CHECK_EQ(x.shape(), xshape)
  Tensor dx(xshape, dev, dtype);

  Tensor dw(W.shape(), dev, dtype);

  Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-   CHECK_EQ(hx.shape(), state_shape)
  Tensor dhx(state_shape, dev, dtype);

  Tensor dcx;
  if (crh.has_cell_)
    dcx.ResetLike(dhx);

  dw.SetValue(0.0f);
-   Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
+   Block *yb = y.block(), *dyb = dY.block(), *dhyb = dhy.block(),
        *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
        *wb = W.block(), *dwb = dw.block(), *hxb = hx.block(),
        *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),