Skip to content

Commit

Permalink
generic: gpu: convolution/deconvolution/softmax: add missing checks
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 authored and sgeor255 committed Oct 28, 2024
1 parent 301eeff commit b837aa1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 11 deletions.
7 changes: 1 addition & 6 deletions src/gpu/generic/sycl/ref_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {

status_t init(impl::engine_t *engine) {
using namespace data_type;
using sm = primitive_attr_t::skip_mask_t;

const memory_desc_wrapper data_d(src_md());
const memory_desc_wrapper diff_weights_d(diff_weights_md());
Expand All @@ -210,11 +209,7 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
data_d, diff_weights_d, diff_dst_d)
&& check_convolution_formats(
data_d, diff_weights_d, diff_dst_d)
&& attr()->has_default_values(sm::scales_runtime
| sm::zero_points_runtime | sm::sum_dt)
&& IMPLICATION(!attr()->scales_.has_default_values(),
attr_scales_ok()
&& check_convolution_scales_types(attr()))
&& attr()->has_default_values()
&& set_default_alg_kind(alg_kind::convolution_direct);
if (!ok) return status::unimplemented;

Expand Down
5 changes: 1 addition & 4 deletions src/gpu/generic/sycl/ref_deconvolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ struct ref_deconvolution_bwd_weights_t

status_t init(impl::engine_t *engine) {
using namespace data_type;
using sm = primitive_attr_t::skip_mask_t;

const memory_desc_wrapper data_d(src_md());
const memory_desc_wrapper diff_weights_d(diff_weights_md());
Expand All @@ -57,9 +56,7 @@ struct ref_deconvolution_bwd_weights_t
data_d, diff_weights_d, diff_dst_d)
&& check_convolution_formats(
data_d, diff_weights_d, diff_dst_d)
&& attr()->has_default_values(sm::scales_runtime
| sm::zero_points_runtime | sm::post_ops
| sm::sum_dt)
&& attr()->has_default_values()
&& desc()->alg_kind == alg_kind::deconvolution_direct;
if (!ok) return status::unimplemented;

Expand Down
2 changes: 1 addition & 1 deletion src/gpu/generic/sycl/ref_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
&& sycl_post_ops_t::post_ops_ok(attr(), true, false)
&& set_default_formats() == status::success
&& attr_.set_default_formats(dst_md()) == status::success
&& check_formats(src_md(), dst_md())
&& check_formats(diff_src_md(), diff_dst_md())
&& md_dims_in_range(src_md());

if (!ok) return status::unimplemented;
Expand Down

0 comments on commit b837aa1

Please sign in to comment.