Skip to content
This repository was archived by the owner on Aug 5, 2022. It is now read-only.

Commit ef4b23d

Browse files
committed
Fix SSD/ResNet50 winograd training/test accuracy/assertion issues.
It's mainly due to mkldnn reorder implementation, which only support nchw input format to wino_fmt output reorder. Original framework code path may create wino_fmt to nchw and nchw16i16o to wino_fmt reorder primitive, but those are not supported by mkldnn and will trigger assertion.
1 parent 740351e commit ef4b23d

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/caffe/layers/mkldnn_convolution_layer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ void MKLDNNConvolutionLayer<Dtype>::InitConvolutionFwd(const vector<Blob<Dtype>*
495495
fwd_top_data->name = "fwd_top_data @ " + this->layer_param_.name();
496496
fwd_top_data_memory = fwd_top_data->create_output_memory();
497497

498-
bool is_wino = conv_algorithm == algorithm::convolution_winograd ? true : false;
498+
bool is_wino = (prv_fwd_weights_data_memory_pd->desc().data.format == memory::format::wino_fmt);
499499
if (fwd_weights_data == NULL) {
500500
if (this->need_quantize_){
501501
int count = 1; //single channel

src/caffe/mkldnn_memory.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,13 @@ template <typename Dtype, bool is_diff>
153153
#ifdef DEBUG
154154
LOG(INFO) << "Formats of blob-prv-memory-pd and this-prv-memory-pd are not equal !";
155155
#endif
156-
this->set_extprv_memory_pd(blob_prv_mkldnn_mem_descr->prv_memory_pd(), scale, blob_prv_mkldnn_mem_descr->get_scale(), blob_prv_mkldnn_mem_descr->get_sum());
156+
if (!is_wino)
157+
this->set_extprv_memory_pd(blob_prv_mkldnn_mem_descr->prv_memory_pd(), scale, blob_prv_mkldnn_mem_descr->get_scale(), blob_prv_mkldnn_mem_descr->get_sum());
158+
else
159+
// If the blob's prv is using wino_fmt, it is only able to accept nchw to wino reorder due to mkldnn implementation. Other input formats, such as nchw16i16o
160+
// to wino reorder, are not supported by mkldnn.
161+
// Therefore we have to force the blob state from prv to cpu to get nchw input format at first.
162+
if (is_diff) blob->mutable_cpu_diff(); else blob->mutable_cpu_data();
157163
}
158164
}
159165
}

0 commit comments

Comments
 (0)