From 32be96c327946fe620931aea8d8b2b05b8302765 Mon Sep 17 00:00:00 2001 From: "Simonov, Alexander" Date: Tue, 11 Feb 2025 08:02:18 -0800 Subject: [PATCH] cpu: x64: pooling: use shared f8 emulators --- src/cpu/x64/jit_uni_pool_kernel.cpp | 81 ++++++++++++++--------------- src/cpu/x64/jit_uni_pool_kernel.hpp | 4 +- src/cpu/x64/utils/jit_io_helper.cpp | 10 +++- src/cpu/x64/utils/jit_io_helper.hpp | 5 +- 4 files changed, 52 insertions(+), 48 deletions(-) diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp index dca6c01a364..9127c10f28a 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.cpp +++ b/src/cpu/x64/jit_uni_pool_kernel.cpp @@ -66,47 +66,9 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( } } - if (use_fp8_emulation() || has_f8_e5m2_binary_postops - || has_f8_e4m3_binary_postops) { - if (utils::one_of(data_type::f8_e5m2, ajpp.src_dt, ajpp.dst_dt) - || has_f8_e5m2_binary_postops) - f8_e5m2_emu_ = utils::make_unique(this, - fp8_emu_reserv_1, fp8_emu_reserv_2, fp8_emu_reserv_3, - fp8_tmp_mask, fp8_emu_reg64); - if (utils::one_of(data_type::f8_e4m3, ajpp.src_dt, ajpp.dst_dt) - || has_f8_e4m3_binary_postops) - f8_e4m3_emu_ = utils::make_unique(this, - fp8_emu_reserv_1, fp8_emu_reserv_2, fp8_emu_reserv_3, - fp8_emu_reserv_4, fp8_emu_reserv_5, fp8_emu_reg64); - } - const auto tail_size = isa == sse41 ? jpp.c_tail % sse41_single_block_size : jpp.c_tail; - if (jpp.with_postops) { - static constexpr bool preserve_gpr = true; - static constexpr bool preserve_vmm = true; - static constexpr bool use_exact_tail_scalar_bcast = false; - - const binary_injector::rhs_arg_static_params_t rhs_sp { - static_cast(this->xmm4.getIdx()), this->r14, - this->r15, this->r13, preserve_gpr, preserve_vmm, - GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), - memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp - ? jpp.tmp_md - : *dst_md), - static_cast(tail_size), k_c_tail_mask, - use_exact_tail_scalar_bcast}; - - const binary_injector::static_params_t bsp {reg_param, - get_supported_bcast_strategies(), rhs_sp, f8_e5m2_emu_.get(), - f8_e4m3_emu_.get()}; - - postops_injector_ - = utils::make_unique>( - this, jpp.post_ops, bsp); - } - io::io_tail_conf_t io_tail_conf(jpp.c_block, tail_size, k_c_tail_mask.getIdx(), vmm_c_tail_mask.getIdx(), tmp_gpr); @@ -131,9 +93,48 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( // 4 bytes. jit_io_helper_t is not used for processing indices of type u8. if (jpp.ind_dt == data_type::s32) dtypes.insert(data_type::f32); if (jpp.needs_f32_accum_for_bf16) dtypes.insert(data_type::f32); + // This is required to guarantee that f8 emulators are created by io injector. + if (has_f8_e5m2_binary_postops) dtypes.insert(data_type::f8_e5m2); + if (has_f8_e4m3_binary_postops) dtypes.insert(data_type::f8_e4m3); io_ = io_mdt_helper(this, jpp.isa, dtypes, {}, io_tail_conf, io_bf16_conf, {}, utils::nullopt, io_fp8_conf); + + if (jpp.with_postops) { + static constexpr bool preserve_gpr = true; + static constexpr bool preserve_vmm = true; + static constexpr bool use_exact_tail_scalar_bcast = false; + + const binary_injector::rhs_arg_static_params_t rhs_sp { + static_cast(this->xmm4.getIdx()), this->r14, + this->r15, this->r13, preserve_gpr, preserve_vmm, + GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), + memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp + ? jpp.tmp_md + : *dst_md), + static_cast(tail_size), k_c_tail_mask, + use_exact_tail_scalar_bcast}; + + std::shared_ptr f8_e5m2_emu; + auto f8_e5m2_io_helper = io_.at(data_type::f8_e5m2); + if (f8_e5m2_io_helper) + f8_e5m2_emu = std::dynamic_pointer_cast( + f8_e5m2_io_helper->get_fp8_emu()); + + std::shared_ptr f8_e4m3_emu; + auto f8_e4m3_io_helper = io_.at(data_type::f8_e4m3); + if (f8_e4m3_io_helper) + f8_e4m3_emu = std::dynamic_pointer_cast( + f8_e4m3_io_helper->get_fp8_emu()); + + const binary_injector::static_params_t bsp {reg_param, + get_supported_bcast_strategies(), rhs_sp, f8_e5m2_emu.get(), + f8_e4m3_emu.get()}; + + postops_injector_ + = utils::make_unique>( + this, jpp.post_ops, bsp); + } } static status_t set_binary_postops_formats( @@ -1525,11 +1526,9 @@ void jit_uni_pool_kernel::generate() { this->postamble(); + io_.prepare_table_fp8(); if (jpp.with_eltwise && postops_injector_) postops_injector_->prepare_table(/* generate = */ true); - if (f8_e5m2_emu_) f8_e5m2_emu_->prepare_table(); - if (f8_e4m3_emu_) f8_e4m3_emu_->prepare_table(); - io_.prepare_table_fp8(); } template struct jit_uni_pool_kernel; diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp index 6f239882d04..482ab375be0 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.hpp +++ b/src/cpu/x64/jit_uni_pool_kernel.hpp @@ -254,11 +254,9 @@ struct jit_uni_pool_kernel : public jit_generator { return jpp.is_fp8 && is_superset(isa, avx512_core_fp16); } - std::unique_ptr f8_e5m2_emu_; - std::unique_ptr f8_e4m3_emu_; + io::jit_io_multi_dt_helper_t io_; std::unique_ptr> postops_injector_; - io::jit_io_multi_dt_helper_t io_; }; } // namespace x64 diff --git a/src/cpu/x64/utils/jit_io_helper.cpp b/src/cpu/x64/utils/jit_io_helper.cpp index 63c1a4d86cd..9a6c50c0d20 100644 --- a/src/cpu/x64/utils/jit_io_helper.cpp +++ b/src/cpu/x64/utils/jit_io_helper.cpp @@ -148,14 +148,14 @@ jit_io_helper_t::jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa, assert(fp8_conf.has_value() && "Config for fp8 emulation is not set."); switch (data_type_) { case data_type::f8_e5m2: - fp8_emu_ = utils::make_unique(host_, + fp8_emu_ = std::make_shared(host_, fp8_conf->fp8_emu_reserv_1_, fp8_conf->fp8_emu_reserv_2_, fp8_conf->fp8_emu_reserv_3_, fp8_conf->kmask_aux_, fp8_conf->reg_tmp_); break; case data_type::f8_e4m3: - fp8_emu_ = utils::make_unique(host_, + fp8_emu_ = std::make_shared(host_, fp8_conf->fp8_emu_reserv_1_, fp8_conf->fp8_emu_reserv_2_, fp8_conf->fp8_emu_reserv_3_, @@ -743,6 +743,12 @@ void jit_io_helper_t::merge_interleaved_to_plain( host_->vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31); } +template +std::shared_ptr +jit_io_helper_t::get_fp8_emu() const { + return fp8_emu_; +} + template void jit_io_helper_t::store(const Vmm &src_raw_vmm, const Xbyak::Address &dst_raw_addr, const bool tail) { diff --git a/src/cpu/x64/utils/jit_io_helper.hpp b/src/cpu/x64/utils/jit_io_helper.hpp index 7b03598bd12..0f1cb6eddb9 100644 --- a/src/cpu/x64/utils/jit_io_helper.hpp +++ b/src/cpu/x64/utils/jit_io_helper.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -199,6 +199,7 @@ class jit_io_helper_t { const Vmm &dst_even_vmm, const Vmm &dst_odd_vmm); void merge_interleaved_to_plain( const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0); + std::shared_ptr get_fp8_emu() const; private: bool is_data_type_supported(const data_type_t dt); @@ -242,7 +243,7 @@ class jit_io_helper_t { const bool f16_supported_; const bool fp8_supported_; std::unique_ptr bf16_emu_; - std::unique_ptr fp8_emu_; + std::shared_ptr fp8_emu_; const io_conf_t io_conf_; const utils::optional_t tail_conf_; const utils::optional_t bf16_conf_;