Skip to content
This repository was archived by the owner on Aug 5, 2022. It is now read-only.

Commit cef51e9

Browse files
Yu, ChongGerrit Code Review
authored andcommitted
Merge "Relu optimization related to ICL-84."
2 parents 61d99ed + 5bc1e99 commit cef51e9

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

src/caffe/layers/mkldnn_relu_layer.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,31 @@ void MKLDNNReLULayer<Dtype>::InitReLUBwd(const vector<Blob<Dtype>*>& top
188188
if (top_diff_is_prv) {
189189
shared_ptr<MKLDNNMemoryDescriptor<Dtype, /* is_diff */ true> > mem_descr
190190
= get_mkldnn_prv_descriptor<Dtype, /* is_diff */ true>(top[0]);
191+
#ifdef DEBUG
192+
memory::format bwd_prv_top_diff_mfmt = static_cast<memory::format>(mem_descr->prv_memory_pd()->desc().data.format);
193+
LOG(INFO) << "MKLDNNReLULayer<Dtype>::InitReLUBwd: memory format of prv top diff is: " << bwd_prv_top_diff_mfmt;
194+
#endif
191195
top_diff_md.reset(new memory::desc(mem_descr->prv_memory_pd()->desc()));
192196
usr_diff_mpd = mem_descr->usr_memory_pd();
193197
prv_diff_mpd = mem_descr->prv_memory_pd();
194198
} else {
199+
bool bottom_data_is_prv = (const_cast<Dtype*>(bottom[0]->prv_data()) != NULL);
200+
if (bottom_data_is_prv) {
201+
shared_ptr<MKLDNNMemoryDescriptor<Dtype, false> > mem_descr
202+
= get_mkldnn_prv_descriptor<Dtype, false>(bottom[0]);
203+
#ifdef DEBUG
204+
memory::format fwd_prv_bottom_data_mfmt = static_cast<memory::format>(mem_descr->prv_memory_pd()->desc().data.format);
205+
LOG(INFO) << "MKLDNNReLULayer<Dtype>::InitReLUBwd: memory format of prv bottom data is: " << fwd_prv_bottom_data_mfmt;
206+
LOG(INFO) << "MKLDNNReLULayer<Dtype>::InitReLUBwd: Reorder the usr top diff to the format of prv bottom data! (Performance consideration)";
207+
#endif
208+
prv_diff_mpd = mem_descr->prv_memory_pd();
209+
//top[0]->prv_data() is empty, however top[0]->get_prv_diff_descriptor() has value.
210+
//Find root cause in the mkldnn_memory: create_output_memory() and sync_before_write() functions.
211+
//But that a major fix, will lead the nan in the AlexNet training.
212+
//So need investigation further, however, this will fix ICL-84.
213+
top[0]->set_prv_diff_descriptor(NULL);
214+
}
215+
195216
top_diff_md.reset(new memory::desc({{n, ic, ih, iw}}, mpcsn, memory::format::nchw));
196217
usr_diff_mpd.reset(new memory::primitive_desc(*top_diff_md, cpu_engine));
197218
}

src/caffe/mkldnn_memory.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,16 @@ void MKLDNNMemoryDescriptor<Dtype, is_diff>::sync_before_write(bool inplace)
374374
this->_blob->set_prv_data_descriptor(this->get_shared_ptr(), this->conversion_needed() ? false : true);
375375
}
376376
}
377+
//Fix me: this->conversion_needed() == false means diff/data is in the CPU, no need to set the prv_diff/data_descriptor
378+
/*
379+
if ((!inplace) && (this->conversion_needed())) {
380+
if (is_diff) {
381+
this->_blob->set_prv_diff_descriptor(this->get_shared_ptr(), false);
382+
} else {
383+
this->_blob->set_prv_data_descriptor(this->get_shared_ptr(), false);
384+
}
385+
}
386+
*/
377387
}
378388

379389
template <typename Dtype, bool is_diff>
@@ -420,7 +430,7 @@ shared_ptr<primitive> MKLDNNMemoryDescriptor<Dtype, is_diff>::create_input(bool
420430
template <typename Dtype, bool is_diff>
421431
shared_ptr<memory> MKLDNNMemoryDescriptor<Dtype, is_diff>::create_output_memory(bool inplace)
422432
{
423-
// TODO: need to iptimize code
433+
// TODO: need to optimize code
424434
shared_ptr<memory> omem = create_output_memory(this->_blob);
425435
if(!inplace) {
426436
if(is_diff) {
@@ -429,6 +439,16 @@ shared_ptr<memory> MKLDNNMemoryDescriptor<Dtype, is_diff>::create_output_memory(
429439
this->_blob->set_prv_data_descriptor(this->get_shared_ptr(), this->conversion_needed() ? false : true);
430440
}
431441
}
442+
/*
443+
//Fix me: this->conversion_needed() == false means diff/data is in the CPU, no need to set the prv_diff/data_descriptor
444+
if ((!inplace) && (this->conversion_needed())) {
445+
if (is_diff) {
446+
this->_blob->set_prv_diff_descriptor(this->get_shared_ptr(), false);
447+
} else {
448+
this->_blob->set_prv_data_descriptor(this->get_shared_ptr(), false);
449+
}
450+
}
451+
*/
432452
return omem;
433453
}
434454

0 commit comments

Comments
 (0)