Commit 7841ac5

gpu: nvidia: ip: adjust benchdnn error threshold
1 parent ea6c0b7 commit 7841ac5

8 files changed: +79 -20 lines

src/gpu/nvidia/README.md (+8)

@@ -215,6 +215,14 @@ limitations when using Nvidia backend for eltwise primitive:
 The inner product primitives is an implementation of matrix multiplication plus
 bias activation. There are two implementation of inner product in cuDNN backend.
 
+With `sum` post-op, the accumulation mode attribute affects behaviour as
+follows:
+- `relaxed`: Uses GEMM’s beta parameter for a fused, optimised sum post-op but
+  may reduce output precision for large `f16` inputs.
+- `strict` (default): Converts GEMM output to `f32`, performs sum as a separate
+  operation, then converts it back to the original type. This is more precise
+  but less performant.
+
 #### Using GEMM
 
 The default backend for inner product is the gemm backend using `cublasGemmEx`
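Note (not part of the diff): the README text above describes behaviour that an application opts into through the accumulation mode primitive attribute. The sketch below is an illustrative guess at how that request looks with the oneDNN C++ API; the engine index, tensor shapes, and format tags are assumptions chosen only for the example.

```cpp
// Minimal sketch, not taken from this commit: an f16 inner product with a
// `sum` post-op that requests the relaxed accumulation mode. Shapes, the
// engine index, and format tags are illustrative assumptions.
#include "oneapi/dnnl/dnnl.hpp"
using namespace dnnl;

void f16_inner_product_with_fused_sum() {
    engine eng(engine::kind::gpu, 0); // Nvidia backend when built with SYCL + cuDNN/cuBLAS
    stream strm(eng);

    const memory::dim MB = 128, IC = 1024, OC = 1024;
    auto src_md = memory::desc({MB, IC}, memory::data_type::f16, memory::format_tag::nc);
    auto wei_md = memory::desc({OC, IC}, memory::data_type::f16, memory::format_tag::oi);
    auto dst_md = memory::desc({MB, OC}, memory::data_type::f16, memory::format_tag::nc);

    // sum post-op: dst = ip(src, wei) + 1.0 * dst
    post_ops po;
    po.append_sum(1.f);

    primitive_attr attr;
    attr.set_post_ops(po);
    // `relaxed` allows the fused GEMM-beta sum described in the README;
    // the default (`strict`) keeps the separate f32 sum.
    attr.set_accumulation_mode(accumulation_mode::relaxed);

    auto pd = inner_product_forward::primitive_desc(
            eng, prop_kind::forward_inference, src_md, wei_md, dst_md, attr);
    auto ip = inner_product_forward(pd);

    memory src(pd.src_desc(), eng), wei(pd.weights_desc(), eng), dst(pd.dst_desc(), eng);
    ip.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei}, {DNNL_ARG_DST, dst}});
    strm.wait();
}
```

Leaving the attribute at its default keeps the `strict` behaviour, i.e. the separate `f32` sum described above.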

src/gpu/nvidia/cudnn_conv_inner_product.hpp (+4 -3)

@@ -91,7 +91,8 @@ struct cudnn_conv_inner_product_fwd_t : public cudnn_inner_product_fwd_t {
                     new cudnn_conv_inner_product_fwd_impl_t());
 
             auto st = inner_product_impl_->init(engine, this, with_relu(),
-                    with_eltwise(), with_sum(), use_fused_path_for_blocking);
+                    with_eltwise(), with_sum(), use_fused_path_for_blocking,
+                    false);
             return st;
         }
         bool with_eltwise() const {
@@ -250,7 +251,7 @@ struct cudnn_conv_inner_product_bwd_data_t
                     new cudnn_conv_inner_product_bwd_data_impl_t());
 
             return inner_product_impl_->init(
-                    engine, this, false, false, false, false);
+                    engine, this, false, false, false, false, false);
         }
 
         status_t set_default_params() {
@@ -341,7 +342,7 @@ struct cudnn_conv_inner_product_bwd_weights_t
                     new cudnn_conv_inner_product_bwd_weights_impl_t());
 
             return inner_product_impl_->init(
-                    engine, this, false, false, false, false);
+                    engine, this, false, false, false, false, false);
         }
 
         status_t set_default_params() {

src/gpu/nvidia/cudnn_conv_inner_product_impl.hpp (+5 -3)

@@ -117,7 +117,7 @@ struct cudnn_conv_inner_product_fwd_impl_t
     }
     virtual status_t init(impl::engine_t *engine, inner_product_pd_t *pd,
             bool with_relu, bool with_eltwise, bool with_sum,
-            bool use_fuse_path_for_blocking) override {
+            bool use_fuse_path_for_blocking, bool /* use_f32_sum */) override {
         with_bias_ = pd->with_bias();
         with_relu_ = with_relu;
         with_eltwise_ = with_eltwise;
@@ -424,7 +424,8 @@ struct cudnn_conv_inner_product_bwd_data_impl_t
     cudnnTensorFormat_t diff_source_format_;
     virtual status_t init(impl::engine_t *engine, inner_product_pd_t *pd,
             bool /*with_relu*/, bool /*with_eltwise*/, bool /*with_sum */,
-            bool /*using_fused_path_for_blocking*/) override {
+            bool /*using_fused_path_for_blocking*/,
+            bool /* use_f32_sum */) override {
         // Pad out the dimensions to 4
         if (pd->ndims() > CUDNN_DIM_MAX || pd->ndims() < 2) {
             return status::invalid_arguments;
@@ -575,7 +576,8 @@ struct cudnn_conv_inner_product_bwd_weights_impl_t
 
     virtual status_t init(impl::engine_t *engine, inner_product_pd_t *pd,
             bool /*with_relu*/, bool /*with_eltwise*/, bool /*with_sum */,
-            bool /*using_fused_path_for_blocking*/) override {
+            bool /*using_fused_path_for_blocking*/,
+            bool /* use_f32_sum */) override {
         // If any of the dimensions are 0 we should not continue with creating
         // cudnn descriptors
         with_bias_ = pd->with_bias();

src/gpu/nvidia/cudnn_gemm_inner_product.hpp (+7 -3)

@@ -222,10 +222,14 @@ struct cudnn_gemm_inner_product_fwd_t : public cudnn_inner_product_fwd_t {
                     && (gemm_compatible || need_reorder);
             if (!ok) return status::unimplemented;
 
+            const bool is_relaxed_acc_mode
+                    = attr()->acc_mode_ == dnnl_accumulation_mode_relaxed;
+            const bool use_f32_sum = with_sum && !is_relaxed_acc_mode;
+
             inner_product_impl_.reset(
                     new cudnn_gemm_inner_product_fwd_impl_t());
             return inner_product_impl_->init(engine, this, with_eltwise,
-                    with_eltwise, with_sum, need_reorder);
+                    with_eltwise, with_sum, need_reorder, use_f32_sum);
         }
 
         status_t set_default_params() {
@@ -289,7 +293,7 @@ struct cudnn_gemm_inner_product_bwd_data_t
                     new cudnn_gemm_inner_product_bwd_data_impl_t());
 
             return inner_product_impl_->init(
-                    engine, this, false, false, false, need_reorder);
+                    engine, this, false, false, false, need_reorder, false);
         }
 
         status_t set_default_params() {
@@ -345,7 +349,7 @@ struct cudnn_gemm_inner_product_bwd_weights_t
             inner_product_impl_.reset(
                     new cudnn_gemm_inner_product_bwd_weights_impl_t());
             return inner_product_impl_->init(
-                    engine, this, false, false, false, need_reorder);
+                    engine, this, false, false, false, need_reorder, false);
         }
 
         status_t set_default_params() {

src/gpu/nvidia/cudnn_gemm_inner_product_impl.hpp (+38 -9)

@@ -75,10 +75,12 @@ struct cudnn_gemm_inner_product_fwd_impl_t
     bool use_acc_dst_;
     cudnnTensorDescriptor_t y_acc_desc_;
     bool need_reorder_;
+    cudnnTensorDescriptor_t bias_f32_desc_;
+    bool with_f32_sum_ = false;
 
     virtual status_t init(impl::engine_t *, inner_product_pd_t *pd,
-            bool with_relu, bool with_eltwise, bool with_sum,
-            bool need_reorder) override {
+            bool with_relu, bool with_eltwise, bool with_sum, bool need_reorder,
+            bool use_f32_sum) override {
         need_reorder_ = need_reorder;
         // GEMM is column major, here the data is row major.
         // By switching the weight and source we convert the row major to
@@ -121,8 +123,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
         use_acc_dst_ = ((pd->dst_md()->data_type == data_type::s8)
                 || (with_bias_
                         && pd->weights_md(1)->data_type
-                                != pd->dst_md()->data_type));
+                                != pd->dst_md()->data_type)
+                || use_f32_sum);
         with_sum_ = with_sum;
+        with_f32_sum_ = use_f32_sum;
         // scaling factor to add the previous destination value to the current
         // computation. This is equivalent of
         sum_scale_ = sum_scale(pd);
@@ -154,12 +158,23 @@ struct cudnn_gemm_inner_product_fwd_impl_t
 
         if (with_bias_) {
             CHECK(convert_data_type(pd->weights_md(1), &data_types_[io::bia]));
+
             // format is always nchw
             set_bias_dims(CUDNN_TENSOR_NCHW, ndims_, pd->OC());
 
             CHECK(create_and_set_tensor_descriptor(&tensor_descs_[io::bia],
                     data_types_[io::bia], ndims_, dims_[io::bia],
                     strides_[io::bia]));
+
+            if (with_f32_sum_) {
+                pd->scratchpad_registry().registrar().book(
+                        memory_tracking::names::key_iprod_bias_bf16_convert_wsp,
+                        memory_desc_wrapper(pd->weights_md(1)).nelems(),
+                        types::data_type_size(data_type::f32));
+                CHECK(create_and_set_tensor_descriptor(&bias_f32_desc_,
+                        CUDNN_DATA_FLOAT, ndims_, dims_[io::bia],
+                        strides_[io::bia]));
+            }
         }
         if (use_acc_dst_) {
             pd->scratchpad_registry().registrar().book(
@@ -178,10 +193,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
 
     void execute(cudnnHandle_t cudnn_handle, cublasHandle_t cublas_handle,
            const std::vector<void *> &args) const override {
-        assert(args.size() == 9);
+        assert(args.size() == 10);
         auto x = args[0], w = args[1], b = args[2], y = args[3],
              workspace = args[4], src_scale = args[6], wei_scale = args[7],
-             dst_scale = args[8];
+             dst_scale = args[8], bias_f32 = args[9];
         auto w_arg = w;
         if (need_reorder_) {
             void *transformed_w = args[5];
@@ -222,8 +237,18 @@ struct cudnn_gemm_inner_product_fwd_impl_t
 
         if (with_bias_) {
             float alpha = 1.0f;
-            CUDNN_EXECUTE_FUNC(cudnnAddTensor, cudnn_handle, &alpha,
-                    tensor_descs_[io::bia], b, &alpha, y_acc_desc_, y_dst);
+            float beta = 0.f;
+            auto bias = b;
+            auto bias_desc = tensor_descs_[io::bia];
+            if (with_f32_sum_) {
+                cudnnTransformTensor(cudnn_handle, &alpha,
+                        tensor_descs_[io::bia], b, &beta, bias_f32_desc_,
+                        bias_f32);
+                bias = bias_f32;
+                bias_desc = bias_f32_desc_;
+            }
+            CUDNN_EXECUTE_FUNC(cudnnAddTensor, cudnn_handle, &alpha, bias_desc,
+                    bias, &alpha, y_acc_desc_, y_dst);
         }
         if (with_eltwise_) {
             CUDNN_EXECUTE_FUNC(cudnnActivationForward, cudnn_handle, act_desc_,
@@ -271,6 +296,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
 
         return status::success;
     }
+
+    ~cudnn_gemm_inner_product_fwd_impl_t() {
+        if (with_f32_sum_) { cudnnDestroyTensorDescriptor(bias_f32_desc_); }
+    }
 };
 
 struct cudnn_gemm_inner_product_bwd_data_impl_t
@@ -281,7 +310,7 @@ struct cudnn_gemm_inner_product_bwd_data_impl_t
 
     virtual status_t init(impl::engine_t *, inner_product_pd_t *pd,
             bool /*with_relu*/, bool /*with_eltwise*/, bool /*with_sum */,
-            bool need_reorder) override {
+            bool need_reorder, bool /* use_f32_sum */) override {
         need_reorder_ = need_reorder;
 
         // GEMM is column major, here the data is row major.
@@ -365,7 +394,7 @@ struct cudnn_gemm_inner_product_bwd_weights_impl_t
     }
     virtual status_t init(impl::engine_t *engine, inner_product_pd_t *pd,
             bool /*with_relu*/, bool /*with_eltwise*/, bool /*with_sum */,
-            bool need_reorder) override {
+            bool need_reorder, bool /* use_f32_sum */) override {
         need_reorder_ = need_reorder;
         with_bias_ = pd->with_bias();
 

src/gpu/nvidia/cudnn_inner_product.cpp (+3)

@@ -49,6 +49,8 @@ status_t cudnn_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const {
                 memory_tracking::names::key_iprod_int_dat_in_acc_dt);
         auto arg_spacial_scratch
                 = CTX_SCRATCH_SYCL_MEMORY(memory_tracking::names::key_none);
+        auto arg_f32_bias_scratch = CTX_SCRATCH_SYCL_MEMORY(
+                memory_tracking::names::key_iprod_bias_bf16_convert_wsp);
         compat::host_task(cgh, [=, this](const compat::interop_handle &ih) {
             auto &sycl_engine = *utils::downcast<nvidia::engine_t *>(
                     cuda_stream->engine());
@@ -72,6 +74,7 @@ status_t cudnn_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const {
             args.push_back(arg_src_scale.get_native_pointer(ih));
             args.push_back(arg_wei_scale.get_native_pointer(ih));
             args.push_back(arg_dst_scale.get_native_pointer(ih));
+            args.push_back(arg_f32_bias_scratch.get_native_pointer(ih));
 
             pd()->inner_product_impl_->execute(
                     cudnn_handle, cublas_handle, args);

src/gpu/nvidia/cudnn_inner_product_impl.hpp (+3 -1)

@@ -146,12 +146,14 @@ struct cudnn_inner_product_impl_base_t {
     virtual status_t init(impl::engine_t * /*engine*/,
             inner_product_pd_t * /*pd*/, bool /*with_relu*/,
             bool /*with_eltwise*/, bool /*with_sum */,
-            bool /*using_fused_path_for_blocking*/)
+            bool /*using_fused_path_for_blocking*/, bool /* use_f32_sum */)
             = 0;
 
     virtual void execute(cudnnHandle_t /*handle*/,
             cublasHandle_t /*cublas_handle*/,
             const std::vector<void *> & /*args*/) const = 0;
+
+    virtual ~cudnn_inner_product_impl_base_t() = default;
 };
 
 struct cudnn_inner_product_fwd_base_t : public cudnn_inner_product_impl_base_t {

tests/benchdnn/ip/ip.cpp (+11 -1)

@@ -278,7 +278,17 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) {}
 
 void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
         const args_t &ref_args) {
-    cmp.set_threshold(0.f);
+    // The nvidia implementation has different precision guarantees in some cases
+    // for large problems with post-op sum
+    if (is_nvidia_gpu()
+            && prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM) != -1
+            && prb->dst_dt() == dnnl_f16 && (prb->dir & FLAG_FWD)
+            && prb->attr.acc_mode == dnnl_accumulation_mode_relaxed) {
+        const float trh = epsilon_dt(prb->dt[2]);
+        cmp.set_threshold(trh);
+    } else {
+        cmp.set_threshold(0.f);
+    }
 }
 
 std::vector<int> supported_exec_args(dir_t dir) {
