Skip to content

Commit

Permalink
cpu: x64: pooling: init scratchpad refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
asimonov1 committed Jan 26, 2025
1 parent 53082c3 commit f1bb9e5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 30 deletions.
48 changes: 27 additions & 21 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ static status_t set_binary_postops_formats(
}

template <cpu_isa_t isa>
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr,
const pooling_pd_t *ppd) {
status_t jit_uni_pool_kernel<isa>::init_conf(
jit_pool_conf_t &jpp, primitive_attr_t &attr, const pooling_pd_t *ppd) {

const auto &pd = *ppd->desc();
const memory_desc_wrapper src_d(
Expand Down Expand Up @@ -410,6 +409,31 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
jpp.ur_bc_tail = 0;
}

jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block;
if (jpp.needs_f32_accum_for_bf16) {
assert(memory_desc_wrapper(jpp.tmp_md).is_zero()
&& (fmt_tag == nspc_fmt_tag || fmt_tag == blocked_fmt_tag));

dims_t dims {};
utils::array_copy(dims, src_d.dims(), ndims);

const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
dims[0] = nstl::min(dnnl_get_max_threads(), jpp.mb * nb2_c);
dims[1] = jpp.f32_accum_block_size;

memory_desc_init_by_tag(
jpp.tmp_md, ndims, dims, data_type::f32, fmt_tag);
}

jpp.post_ops = attr.post_ops_;

return status::success;
}

template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::init_scratchpad(
const jit_pool_conf_t &jpp, memory_tracking::registrar_t &scratchpad) {

// scratchpad for c_block slice of input and/or output
using namespace memory_tracking::names;
const int nscr = nstl::min(dnnl_get_max_threads(), jpp.mb * jpp.nb_c);
Expand All @@ -427,28 +451,10 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
* nscr);
}

jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block;
if (jpp.needs_f32_accum_for_bf16) {
auto tmp_d = memory_desc_wrapper(jpp.tmp_md);
assert(tmp_d.is_zero()
&& (fmt_tag == nspc_fmt_tag || fmt_tag == blocked_fmt_tag));

dims_t dims {};
utils::array_copy(dims, src_d.dims(), ndims);

const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
dims[0] = nstl::min(dnnl_get_max_threads(), jpp.mb * nb2_c);
dims[1] = jpp.f32_accum_block_size;

memory_desc_init_by_tag(
jpp.tmp_md, ndims, dims, data_type::f32, fmt_tag);

scratchpad.book<char>(key_pool_src_f32_accum, tmp_d.size());
}

jpp.post_ops = attr.post_ops_;

return status::success;
}

static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/x64/jit_uni_pool_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ struct jit_uni_pool_kernel : public jit_generator {

DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)

static status_t init_conf(jit_pool_conf_t &jbp,
memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr,
static status_t init_conf(jit_pool_conf_t &jpp, primitive_attr_t &attr,
const pooling_pd_t *ppd);

static void init_scratchpad(const jit_pool_conf_t &jpp,
memory_tracking::registrar_t &scratchpad);

private:
using Xmm = Xbyak::Xmm;
using Ymm = Xbyak::Ymm;
Expand Down
14 changes: 7 additions & 7 deletions src/cpu/x64/jit_uni_pooling.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2024 Intel Corporation
* Copyright 2017-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -70,10 +70,10 @@ struct jit_uni_pooling_fwd_t : public primitive_t {
if (desc()->alg_kind == alg_kind::pooling_max && is_training)
init_default_ws();

auto scratchpad = scratchpad_registry().registrar();
CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, attr_, this));

CHECK(jit_uni_pool_kernel<isa>::init_conf(
jpp_, scratchpad, attr_, this));
auto scratchpad = scratchpad_registry().registrar();
jit_uni_pool_kernel<isa>::init_scratchpad(jpp_, scratchpad);

return status::success;
}
Expand Down Expand Up @@ -146,10 +146,10 @@ struct jit_uni_pooling_bwd_t : public primitive_t {
compare_ws(hint_fwd_pd_), VERBOSE_WS_MISMATCH);
}

auto scratchpad = scratchpad_registry().registrar();
CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, attr_, this));

CHECK(jit_uni_pool_kernel<isa>::init_conf(
jpp_, scratchpad, attr_, this));
auto scratchpad = scratchpad_registry().registrar();
jit_uni_pool_kernel<isa>::init_scratchpad(jpp_, scratchpad);

return status::success;
}
Expand Down

0 comments on commit f1bb9e5

Please sign in to comment.