Skip to content

Commit a6a303a

Browse files
committed
cpu: rnn: enable matmul impl back for SYCL CPU
Regular, raw CPU pointers can only be used with memory_t objects created for the classic CPU engine. Use service engine to create such objects.
1 parent d80337d commit a6a303a

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/cpu/rnn/ref_rnn.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -830,20 +830,26 @@ template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
830830
rnn_matmul_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
831831
acc_type>::execute_matmul)) {
832832

833-
engine_t *engine = ctx.stream()->engine();
833+
// Service engine is just a global classic CPU engine that is used
834+
// when it's required to create memory_t objects for classic CPU
835+
// engine regardless of the CPU runtime. For example, SYCL CPU engine
836+
// cannot be used to create such objects.
837+
engine_t *service_engine = get_service_engine();
834838
constexpr auto mem_flag = memory_flags_t::use_runtime_ptr;
835839

840+
// a_, b_ and c_ are regular, raw CPU pointers that can only be used with
841+
// memory_t objects created for the classic CPU engine.
836842
std::unique_ptr<memory_t, memory_deleter_t> src_mem;
837843
CHECK(safe_ptr_assign(src_mem,
838-
new memory_t(engine, matmul_prim->pd()->src_md(), mem_flag,
844+
new memory_t(service_engine, matmul_prim->pd()->src_md(), mem_flag,
839845
(void *)(a_))));
840846
std::unique_ptr<memory_t, memory_deleter_t> wei_mem;
841847
CHECK(safe_ptr_assign(wei_mem,
842-
new memory_t(engine, matmul_prim->pd()->weights_md(), mem_flag,
843-
(void *)(b_))));
848+
new memory_t(service_engine, matmul_prim->pd()->weights_md(),
849+
mem_flag, (void *)(b_))));
844850
std::unique_ptr<memory_t, memory_deleter_t> dst_mem;
845851
CHECK(safe_ptr_assign(dst_mem,
846-
new memory_t(engine, matmul_prim->pd()->dst_md(), mem_flag,
852+
new memory_t(service_engine, matmul_prim->pd()->dst_md(), mem_flag,
847853
(void *)(c_))));
848854

849855
exec_args_t matmul_args;

src/cpu/rnn/rnn_utils.hpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -858,17 +858,14 @@ bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
858858
rnn.use_matmul = !rnn.is_brgemm && rnn.is_fwd // TODO: Enable BWD
859859
// TODO: Below checks are for legacy and a performance study is
860860
// required to avoid regressions.
861-
// TODO: using matmul is disabled for SYCL runtime for now.
862-
// Enable it after memory handles issue fix
863861
#if DNNL_X64
864862
&& IMPLICATION(
865863
rnn.is_cell_dt_bf16(), !x64::mayiuse(x64::avx512_core))
866864
&& IMPLICATION(rnn.is_cell_dt_f32() || rnn.is_cell_dt_int8(),
867865
x64::mayiuse(x64::avx2)
868866
&& utils::one_of(rd.cell_kind,
869867
alg_kind::vanilla_gru,
870-
alg_kind::vanilla_augru))
871-
&& (DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL);
868+
alg_kind::vanilla_augru));
872869
#else
873870
&& !rnn.is_cell_dt_f32() && !rnn.is_cell_dt_int8();
874871
#endif

0 commit comments

Comments
 (0)