Skip to content

Commit a68b8ae

Browse files
tczeszun authored and densamoilov committed
x64: conv: avoid overflows and add limit for huge spatial sizes
1 parent d723449 commit a68b8ae

5 files changed

+63
-14
lines changed

src/cpu/x64/jit_avx2_conv_kernel_f32.cpp

+21-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2016-2024 Intel Corporation
2+
* Copyright 2016-2025 Intel Corporation
33
* Copyright 2018 YANDEX LLC
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -1459,6 +1459,26 @@ status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
14591459
jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
14601460
jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
14611461

1462+
jcp.typesize_in = types::data_type_size(src_d.data_type());
1463+
jcp.typesize_out = types::data_type_size(diff_dst_d.data_type());
1464+
1465+
const bool is_src_layout_blocked = jcp.src_tag == dat_tag_nCx8c;
1466+
const bool is_dst_layout_blocked = jcp.dst_tag == dat_tag_nCx8c;
1467+
1468+
dim_t src_size = static_cast<dim_t>(jcp.mb)
1469+
* (is_src_layout_blocked ? rnd_up(jcp.ic, jcp.ic_block) : jcp.ic)
1470+
* jcp.id * jcp.ih * jcp.iw * jcp.typesize_in;
1471+
1472+
VDISPATCH_CONV_IC(src_size <= INT_MAX, VERBOSE_UNSUPPORTED_FEATURE,
1473+
"src size > INT_MAX is not supported");
1474+
1475+
dim_t diff_dst_size = static_cast<dim_t>(jcp.mb)
1476+
* (is_dst_layout_blocked ? rnd_up(jcp.oc, jcp.oc_block) : jcp.oc)
1477+
* jcp.id * jcp.ih * jcp.iw * jcp.typesize_in;
1478+
1479+
VDISPATCH_CONV_IC(diff_dst_size <= INT_MAX, VERBOSE_UNSUPPORTED_FEATURE,
1480+
"diff_dst size > INT_MAX is not supported");
1481+
14621482
return status::success;
14631483
}
14641484

src/cpu/x64/jit_avx512_common_conv_kernel.cpp

+14-4
Original file line number | Diff line number | Diff line change
@@ -4118,14 +4118,24 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
41184118
jcp.typesize_in = typesize;
41194119
jcp.typesize_out = typesize;
41204120

4121+
dim_t src_size = static_cast<dim_t>(jcp.mb)
4122+
* (is_data_layout_nxc ? jcp.ic : rnd_up(jcp.ic, jcp.ic_block))
4123+
* jcp.id * jcp.ih * jcp.iw * jcp.typesize_in;
4124+
4125+
VDISPATCH_CONV_IC(src_size <= INT_MAX, VERBOSE_UNSUPPORTED_FEATURE,
4126+
"src size > INT_MAX is not supported");
4127+
4128+
dim_t diff_dst_size = static_cast<dim_t>(jcp.mb)
4129+
* (is_data_layout_nxc ? jcp.oc : rnd_up(jcp.oc, jcp.oc_block))
4130+
* jcp.id * jcp.ih * jcp.iw * jcp.typesize_in;
4131+
4132+
VDISPATCH_CONV_IC(diff_dst_size <= INT_MAX, VERBOSE_UNSUPPORTED_FEATURE,
4133+
"diff_dst size > INT_MAX is not supported");
4134+
41214135
bool use_nxc_harness = false;
41224136
if (is_data_layout_nxc) {
41234137
dim_t kernel_size = static_cast<dim_t>(jcp.ic) * jcp.oc * jcp.kd
41244138
* jcp.kh * jcp.kw * jcp.typesize_out;
4125-
dim_t src_size = static_cast<dim_t>(jcp.mb) * jcp.ic * jcp.id * jcp.ih
4126-
* jcp.iw * jcp.typesize_in;
4127-
dim_t diff_dst_size = static_cast<dim_t>(jcp.mb) * jcp.oc * jcp.id
4128-
* jcp.ih * jcp.iw * jcp.typesize_in;
41294139
dim_t data_size = src_size + diff_dst_size;
41304140

41314141
// The advantage of the nxc kernel is cache traversal, this comes at a

src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp

+9-2
Original file line number | Diff line number | Diff line change
@@ -4449,8 +4449,15 @@ status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
44494449
&& jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
44504450
if (!args_ok) return status::unimplemented;
44514451

4452-
int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
4453-
int out_row_size = jcp.oc_block * jcp.tr_ow * jcp.typesize_in;
4452+
const auto inp_row_size
4453+
= static_cast<size_t>(jcp.ic_block) * jcp.tr_iw * jcp.typesize_in;
4454+
VDISPATCH_CONV_IC(inp_row_size <= INT_MAX, VERBOSE_UNSUPPORTED_FEATURE,
4455+
"inp_row_size > INT_MAX is not supported");
4456+
const auto out_row_size
4457+
= static_cast<size_t>(jcp.oc_block) * jcp.tr_ow * jcp.typesize_in;
4458+
VDISPATCH_CONV_IC(out_row_size <= INT_MAX, VERBOSE_UNSUPPORTED_FEATURE,
4459+
"out_row_size > INT_MAX is not supported");
4460+
44544461
int full_spat_min_h_block_size
44554462
= nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
44564463
int full_spat_working_set_size

src/cpu/x64/jit_brgemm_conv_bwd_w.cpp

+14-3
Original file line number | Diff line number | Diff line change
@@ -139,9 +139,20 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) {
139139
brgattr.max_top_vpad = 0;
140140
brgattr.max_bottom_vpad = 0;
141141

142-
brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id;
143-
brgattr.LDB2
144-
= jcp_.tr_ow * jcp_.oc_block * jcp_.oh_block * jcp_.od;
142+
const auto lda2_size = static_cast<size_t>(jcp_.tr_iw)
143+
* jcp_.ih_block * jcp_.id;
144+
VDISPATCH_CONV_IC(lda2_size <= INT_MAX,
145+
VERBOSE_UNSUPPORTED_FEATURE,
146+
"lda2_size > INT_MAX is not supported");
147+
brgattr.LDA2 = lda2_size;
148+
149+
const auto ldb2_size = static_cast<size_t>(jcp_.tr_ow)
150+
* jcp_.oc_block * jcp_.oh_block * jcp_.od;
151+
VDISPATCH_CONV_IC(ldb2_size <= INT_MAX,
152+
VERBOSE_UNSUPPORTED_FEATURE,
153+
"ldb2_size > INT_MAX is not supported");
154+
brgattr.LDB2 = ldb2_size;
155+
145156
brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw;
146157
brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block
147158
* jcp_.kd * jcp_.kh * jcp_.kw;

src/cpu/x64/jit_brgemm_conv_utils.cpp

+5-4
Original file line number | Diff line number | Diff line change
@@ -3306,9 +3306,10 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
33063306
jcp.tr_diff_dst_buf_count = jcp.global_transpose
33073307
? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
33083308
: jcp.nthr;
3309-
jcp.tr_src_block_size = jcp.tr_iw * jcp.ic_block * jcp.ih_block * jcp.id;
3310-
jcp.tr_diff_dst_block_size
3311-
= jcp.tr_ow * jcp.oc_block * jcp.oh_block * jcp.od;
3309+
jcp.tr_src_block_size = static_cast<size_t>(jcp.tr_iw) * jcp.ic_block
3310+
* jcp.ih_block * jcp.id;
3311+
jcp.tr_diff_dst_block_size = static_cast<size_t>(jcp.tr_ow) * jcp.oc_block
3312+
* jcp.oh_block * jcp.od;
33123313

33133314
jcp.tr_src_buf_size = jcp.tr_src_block_size
33143315
* (jcp.global_transpose ? 1 : jcp.nb_ic_blocking);
@@ -3368,7 +3369,7 @@ status_t init_scratchpad_bwd_w(memory_tracking::registrar_t &scratchpad,
33683369
// (jcp.tr_diff_dst_buf_size + jcp.tr_iw * jcp.oc_block)
33693370
const auto tr_diff_dst_size
33703371
= jcp.tr_diff_dst_buf_count * jcp.tr_diff_dst_buf_size
3371-
+ jcp.tr_iw * jcp.oc_block;
3372+
+ static_cast<size_t>(jcp.tr_iw) * jcp.oc_block;
33723373

33733374
const size_t min_align = 64;
33743375
scratchpad.book(

0 commit comments

Comments
 (0)