Skip to content

Commit

Permalink
gpu: nvidia: fix relu type cast for dimension (#2519)
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick authored Jan 28, 2025
1 parent 79ae0fc commit 16728ec
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/gpu/nvidia/cudnn_eltwise_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ struct cudnn_eltwise_fwd_impl_t : public cudnn_eltwise_impl_base_t {
if (pd->ndims() > CUDNN_DIM_MAX) { return status::invalid_arguments; }
ndims = pd->ndims() < 4 ? 4 : pd->ndims();

// Refuse shapes whose padded source dimensions exceed what an int can hold,
// reporting status::unimplemented instead of proceeding.
// NOTE(review): presumably convert_dims narrows these values to int for
// cuDNN -- confirm against its definition.
// Only the descriptor's own pd->ndims() entries are inspected: ndims may have
// been raised to 4 just above, and slots past the real rank are not
// dimensions of this tensor (convert_dims is likewise given pd->ndims()).
for (int i = 0; i < pd->ndims(); ++i) {
    if (pd->src_md()->padded_dims[i]
            > std::numeric_limits<int>::max()) {
        return status::unimplemented;
    }
}

// Obtain source and destination dimensions, strides and datatype
convert_dims(pd->src_md()->padded_dims, dims_, pd->ndims());
convert_dims(pd->src_md()->format_desc.blocking.strides, strides_,
Expand Down Expand Up @@ -139,6 +146,12 @@ struct cudnn_eltwise_bwd_impl_t : public cudnn_eltwise_impl_base_t {
if (pd->ndims() > CUDNN_DIM_MAX) { return status::invalid_arguments; }
ndims = pd->ndims() < 4 ? 4 : pd->ndims();

// Refuse shapes whose padded source dimensions exceed what an int can hold,
// reporting status::unimplemented instead of proceeding.
// NOTE(review): presumably convert_dims narrows these values to int for
// cuDNN -- confirm against its definition.
// Only the descriptor's own pd->ndims() entries are inspected: ndims may have
// been raised to 4 just above, and slots past the real rank are not
// dimensions of this tensor (convert_dims is likewise given pd->ndims()).
for (int i = 0; i < pd->ndims(); ++i) {
    if (pd->src_md()->padded_dims[i]
            > std::numeric_limits<int>::max()) {
        return status::unimplemented;
    }
}
// Obtain dimension and strides for the backward eltwise operation
convert_dims(pd->src_md()->padded_dims, dims_, pd->ndims());

Expand Down

0 comments on commit 16728ec

Please sign in to comment.