From 9a37a19498bd5c5f3601dc956ae39824545216d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 16 Oct 2024 12:55:10 +0200 Subject: [PATCH] generic: sycl: softmax: bugfix checks --- src/gpu/generic/sycl/ref_softmax.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gpu/generic/sycl/ref_softmax.hpp b/src/gpu/generic/sycl/ref_softmax.hpp index f1447334061..43f30e54a35 100644 --- a/src/gpu/generic/sycl/ref_softmax.hpp +++ b/src/gpu/generic/sycl/ref_softmax.hpp @@ -48,7 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t { && sycl_post_ops_t::post_ops_ok(attr(), true, false) && set_default_formats() == status::success && attr_.set_default_formats(dst_md()) == status::success - && check_formats(diff_src_md(), diff_dst_md()) + && check_formats(src_md(), dst_md()) && md_dims_in_range(src_md()); if (!ok) return status::unimplemented; @@ -111,7 +111,7 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t { && dst_md()->data_type == diff_dst_md()->data_type && attr()->has_default_values() && set_default_formats() == status::success - && check_formats(src_md(), dst_md()) + && check_formats(diff_src_md(), diff_dst_md()) && md_dims_in_range(diff_dst_md()); if (!ok) return status::unimplemented;