Skip to content

Commit

Permalink
cpu: x64: pooling: use shared f8 emulators
Browse files Browse the repository at this point in the history
  • Loading branch information
asimonov1 committed Feb 11, 2025
1 parent fbf38e7 commit 32be96c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 48 deletions.
81 changes: 40 additions & 41 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,47 +66,9 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
}
}

if (use_fp8_emulation() || has_f8_e5m2_binary_postops
|| has_f8_e4m3_binary_postops) {
if (utils::one_of(data_type::f8_e5m2, ajpp.src_dt, ajpp.dst_dt)
|| has_f8_e5m2_binary_postops)
f8_e5m2_emu_ = utils::make_unique<fp8_emulation_e5m2_t>(this,
fp8_emu_reserv_1, fp8_emu_reserv_2, fp8_emu_reserv_3,
fp8_tmp_mask, fp8_emu_reg64);
if (utils::one_of(data_type::f8_e4m3, ajpp.src_dt, ajpp.dst_dt)
|| has_f8_e4m3_binary_postops)
f8_e4m3_emu_ = utils::make_unique<fp8_emulation_e4m3_t>(this,
fp8_emu_reserv_1, fp8_emu_reserv_2, fp8_emu_reserv_3,
fp8_emu_reserv_4, fp8_emu_reserv_5, fp8_emu_reg64);
}

const auto tail_size
= isa == sse41 ? jpp.c_tail % sse41_single_block_size : jpp.c_tail;

if (jpp.with_postops) {
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = true;
static constexpr bool use_exact_tail_scalar_bcast = false;

const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<std::size_t>(this->xmm4.getIdx()), this->r14,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp
? jpp.tmp_md
: *dst_md),
static_cast<size_t>(tail_size), k_c_tail_mask,
use_exact_tail_scalar_bcast};

const binary_injector::static_params_t bsp {reg_param,
get_supported_bcast_strategies(), rhs_sp, f8_e5m2_emu_.get(),
f8_e4m3_emu_.get()};

postops_injector_
= utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
this, jpp.post_ops, bsp);
}

io::io_tail_conf_t io_tail_conf(jpp.c_block, tail_size,
k_c_tail_mask.getIdx(), vmm_c_tail_mask.getIdx(), tmp_gpr);

Expand All @@ -131,9 +93,48 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
// 4 bytes. jit_io_helper_t is not used for processing indices of type u8.
if (jpp.ind_dt == data_type::s32) dtypes.insert(data_type::f32);
if (jpp.needs_f32_accum_for_bf16) dtypes.insert(data_type::f32);
// This is required to guarantee that f8 emulators are created by io injector.
if (has_f8_e5m2_binary_postops) dtypes.insert(data_type::f8_e5m2);
if (has_f8_e4m3_binary_postops) dtypes.insert(data_type::f8_e4m3);

io_ = io_mdt_helper(this, jpp.isa, dtypes, {}, io_tail_conf, io_bf16_conf,
{}, utils::nullopt, io_fp8_conf);

if (jpp.with_postops) {
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = true;
static constexpr bool use_exact_tail_scalar_bcast = false;

const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<std::size_t>(this->xmm4.getIdx()), this->r14,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp
? jpp.tmp_md
: *dst_md),
static_cast<size_t>(tail_size), k_c_tail_mask,
use_exact_tail_scalar_bcast};

std::shared_ptr<fp8_emulation_e5m2_t> f8_e5m2_emu;
auto f8_e5m2_io_helper = io_.at(data_type::f8_e5m2);
if (f8_e5m2_io_helper)
f8_e5m2_emu = std::dynamic_pointer_cast<fp8_emulation_e5m2_t>(
f8_e5m2_io_helper->get_fp8_emu());

std::shared_ptr<fp8_emulation_e4m3_t> f8_e4m3_emu;
auto f8_e4m3_io_helper = io_.at(data_type::f8_e4m3);
if (f8_e4m3_io_helper)
f8_e4m3_emu = std::dynamic_pointer_cast<fp8_emulation_e4m3_t>(
f8_e4m3_io_helper->get_fp8_emu());

const binary_injector::static_params_t bsp {reg_param,
get_supported_bcast_strategies(), rhs_sp, f8_e5m2_emu.get(),
f8_e4m3_emu.get()};

postops_injector_
= utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
this, jpp.post_ops, bsp);
}
}

static status_t set_binary_postops_formats(
Expand Down Expand Up @@ -1525,11 +1526,9 @@ void jit_uni_pool_kernel<isa>::generate() {

this->postamble();

io_.prepare_table_fp8();
if (jpp.with_eltwise && postops_injector_)
postops_injector_->prepare_table(/* generate = */ true);
if (f8_e5m2_emu_) f8_e5m2_emu_->prepare_table();
if (f8_e4m3_emu_) f8_e4m3_emu_->prepare_table();
io_.prepare_table_fp8();
}

template struct jit_uni_pool_kernel<sse41>;
Expand Down
4 changes: 1 addition & 3 deletions src/cpu/x64/jit_uni_pool_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,9 @@ struct jit_uni_pool_kernel : public jit_generator {
return jpp.is_fp8 && is_superset(isa, avx512_core_fp16);
}

std::unique_ptr<fp8_emulation_e5m2_t> f8_e5m2_emu_;
std::unique_ptr<fp8_emulation_e4m3_t> f8_e4m3_emu_;
io::jit_io_multi_dt_helper_t<Vmm> io_;
std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
postops_injector_;
io::jit_io_multi_dt_helper_t<Vmm> io_;
};

} // namespace x64
Expand Down
10 changes: 8 additions & 2 deletions src/cpu/x64/utils/jit_io_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ jit_io_helper_t<Vmm>::jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa,
assert(fp8_conf.has_value() && "Config for fp8 emulation is not set.");
switch (data_type_) {
case data_type::f8_e5m2:
fp8_emu_ = utils::make_unique<fp8_emulation_e5m2_t>(host_,
fp8_emu_ = std::make_shared<fp8_emulation_e5m2_t>(host_,
fp8_conf->fp8_emu_reserv_1_,
fp8_conf->fp8_emu_reserv_2_,
fp8_conf->fp8_emu_reserv_3_, fp8_conf->kmask_aux_,
fp8_conf->reg_tmp_);
break;
case data_type::f8_e4m3:
fp8_emu_ = utils::make_unique<fp8_emulation_e4m3_t>(host_,
fp8_emu_ = std::make_shared<fp8_emulation_e4m3_t>(host_,
fp8_conf->fp8_emu_reserv_1_,
fp8_conf->fp8_emu_reserv_2_,
fp8_conf->fp8_emu_reserv_3_,
Expand Down Expand Up @@ -743,6 +743,12 @@ void jit_io_helper_t<Vmm>::merge_interleaved_to_plain(
host_->vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
}

// Returns the fp8 emulator stored in this io helper.
//
// The emulator is handed out as a std::shared_ptr copy, so the caller
// becomes a co-owner of the same emulator object instead of having to
// construct its own instance.
// NOTE(review): presumably the returned pointer is empty when this helper
// was built for a non-fp8 data type — confirm against the constructor.
template <typename Vmm>
std::shared_ptr<fp8_emulation_base_t>
jit_io_helper_t<Vmm>::get_fp8_emu() const {
return fp8_emu_;
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store(const Vmm &src_raw_vmm,
const Xbyak::Address &dst_raw_addr, const bool tail) {
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/x64/utils/jit_io_helper.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -199,6 +199,7 @@ class jit_io_helper_t {
const Vmm &dst_even_vmm, const Vmm &dst_odd_vmm);
void merge_interleaved_to_plain(
const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0);
std::shared_ptr<fp8_emulation_base_t> get_fp8_emu() const;

private:
bool is_data_type_supported(const data_type_t dt);
Expand Down Expand Up @@ -242,7 +243,7 @@ class jit_io_helper_t {
const bool f16_supported_;
const bool fp8_supported_;
std::unique_ptr<bf16_emulation_t> bf16_emu_;
std::unique_ptr<fp8_emulation_base_t> fp8_emu_;
std::shared_ptr<fp8_emulation_base_t> fp8_emu_;
const io_conf_t io_conf_;
const utils::optional_t<io_tail_conf_t> tail_conf_;
const utils::optional_t<io_emu_bf16_conf_t> bf16_conf_;
Expand Down

0 comments on commit 32be96c

Please sign in to comment.