diff --git a/src/gpu/nvidia/cudnn_eltwise_impl.hpp b/src/gpu/nvidia/cudnn_eltwise_impl.hpp index aba08f055fb..16134bbabc9 100644 --- a/src/gpu/nvidia/cudnn_eltwise_impl.hpp +++ b/src/gpu/nvidia/cudnn_eltwise_impl.hpp @@ -92,6 +92,13 @@ struct cudnn_eltwise_fwd_impl_t : public cudnn_eltwise_impl_base_t { if (pd->ndims() > CUDNN_DIM_MAX) { return status::invalid_arguments; } ndims = pd->ndims() < 4 ? 4 : pd->ndims(); + for (int i = 0; i < ndims; ++i) { + if (pd->src_md()->padded_dims[i] + > std::numeric_limits::max()) { + return status::unimplemented; + } + } + // Obtain source and destination dimensions, strides and datatype convert_dims(pd->src_md()->padded_dims, dims_, pd->ndims()); convert_dims(pd->src_md()->format_desc.blocking.strides, strides_, @@ -139,6 +146,12 @@ struct cudnn_eltwise_bwd_impl_t : public cudnn_eltwise_impl_base_t { if (pd->ndims() > CUDNN_DIM_MAX) { return status::invalid_arguments; } ndims = pd->ndims() < 4 ? 4 : pd->ndims(); + for (int i = 0; i < ndims; ++i) { + if (pd->src_md()->padded_dims[i] + > std::numeric_limits::max()) { + return status::unimplemented; + } + } // Obtain dimension and strides for the backward eltwise operation convert_dims(pd->src_md()->padded_dims, dims_, pd->ndims());