@@ -69,7 +69,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
         , bwd_top_diff(), bwd_bottom_diff()
         , BatchNormFwd_pd(), BatchNormBwd_pd()
         , scaleshift_memory(), bwd_scaleshift_diff_memory()
-        , output_memory(), bwd_bottom_diff_memory(), inplace_buffer_memory()
+        , output_memory(), bwd_bottom_diff_memory()
         , input_primitive(), bwd_top_diff_primitive()
     {
         PERFORMANCE_EVENT_ID_RESET(perf_id_fw_);
@@ -95,12 +95,10 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
     void InitBatchNormBwd(const vector<Blob<Dtype>*>& top,
                           const vector<bool>& propagate_down,
                           const vector<Blob<Dtype>*>& bottom);
-    void InitBatchNormFwdPrimitive(int stats_batch_idx, bool inplace);
-    void InitBatchNormBwdPrimitive(int stats_batch_idx, bool inplace);
+    void InitBatchNormFwdPrimitive(int stats_batch_idx);
+    void InitBatchNormBwdPrimitive(int stats_batch_idx);
     template <bool diff> shared_ptr<memory> GetStatsBatchMemory(
       shared_ptr<MKLDNNMemoryDescriptor<Dtype, diff> > mkldnn_data, int idx);
-    template <bool diff> shared_ptr<memory> GetStatsBatchMemoryInplace(
-      shared_ptr<MKLDNNMemoryDescriptor<Dtype, diff> > mkldnn_data, int idx, shared_ptr<memory> buffer_memory);
     void InitStatsBatchVars(int batch_size);
     shared_ptr<MKLDNNData<Dtype> > fwd_top_data, fwd_bottom_data;
     shared_ptr<MKLDNNDiff<Dtype> > bwd_top_diff, bwd_bottom_diff;
@@ -112,8 +110,8 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
 
     shared_ptr<memory> scaleshift_memory, bwd_scaleshift_diff_memory;
     shared_ptr<memory> output_memory, bwd_bottom_diff_memory;
-    shared_ptr<memory> inplace_buffer_memory;
-    vector<shared_ptr<memory> > input_stats, output_stats, top_diff_stats, bottom_diff_stats, input_inplace_buffer;
+
+    vector<shared_ptr<memory> > input_stats, output_stats, top_diff_stats, bottom_diff_stats;
 
     shared_ptr<primitive> input_primitive, bwd_top_diff_primitive;
 
@@ -124,6 +122,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
     int stats_batch_size_;
     shared_ptr<Blob<Dtype> > scaleshift_blob_;
     shared_ptr<Blob<Dtype> > scaleshift_acc_;
+    Blob<Dtype> inplace_buffer;
 
     PERFORMANCE_EVENT_ID_DECL(perf_id_fw_);
     PERFORMANCE_EVENT_ID_DECL(perf_id_bw_);
@@ -224,7 +223,7 @@ class MKLDNNInnerProductLayer : public MKLDNNLayer<Dtype> , public InnerProductL
             , bwdd_top_diff_primitive, bwdd_weights_data_primitive
             , bwdw_top_diff_primitive, bwdw_bottom_data_primitive;
     int32_t w_, h_;
-
+
     /* In case of (iter_size > 1) we need additional buffers */
     shared_ptr<MKLDNNDiff<Dtype> > bwdw_weights_diff_iter, bwdw_bias_diff_iter;
     shared_ptr<memory> bwdw_weights_diff_memory_iter, bwdw_bias_diff_memory_iter;
@@ -322,13 +321,14 @@ class MKLDNNPoolingLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
                               ,const vector<Blob<Dtype>*>& bottom);
     virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                               ,const vector<Blob<Dtype>*>& bottom);
+    virtual void compute_output_shape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
 
 private:
     void InitPoolingFwd(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
     void InitPoolingBwd(const vector<Blob<Dtype>*>& bottom
                         , const vector<bool>& propagate_down
                         , const vector<Blob<Dtype>*>& top);
-
+
     shared_ptr<MKLDNNData<Dtype>> fwd_bottom_data, fwd_top_data;
     shared_ptr<MKLDNNDiff<Dtype>> bwd_top_diff, bwd_bottom_diff;
     shared_ptr<pooling_forward::primitive_desc> poolingFwd_pd;
@@ -408,7 +408,7 @@ class MKLDNNConcatLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
         : MKLDNNLayer<Dtype>(), Layer<Dtype>(param),
           concatFwd_pd(), fwd_output_memory(),
           bwd_reorder_input_memory(), bwd_reorder_output_memory(),
-          fwd_top_data(), fwd_bottom_data(), split_channels() {
+          fwd_top_data(), fwd_bottom_data(), split_dims() {
         PERFORMANCE_EVENT_ID_RESET(perf_id_fw_);
         PERFORMANCE_EVENT_ID_RESET(perf_id_bw_);
     }
@@ -440,7 +440,7 @@ class MKLDNNConcatLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
     shared_ptr<MKLDNNDiff<Dtype> > bwd_top_diff;
     vector<shared_ptr<MKLDNNDiff<Dtype> > > bwd_bottom_diff;
     vector<MKLDNNPrimitive<Dtype> > reorders;
-    vector<int> split_channels;
+    vector<int> split_dims;
 
     int32_t num_, width_, height_, channels_, num_concats_;
     int concat_dimension;