@@ -77,8 +77,8 @@ struct cudnn_gemm_inner_product_fwd_impl_t
77
77
bool need_reorder_;
78
78
79
79
virtual status_t init (impl::engine_t *, inner_product_pd_t *pd,
80
- bool with_relu, bool with_eltwise, bool with_sum,
81
- bool need_reorder ) override {
80
+ bool with_relu, bool with_eltwise, bool with_sum, bool need_reorder,
81
+ bool use_f32_sum ) override {
82
82
need_reorder_ = need_reorder;
83
83
// GEMM is column major, here the data is row major.
84
84
// By switching the weight and source we convert the row major to
@@ -121,8 +121,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
121
121
use_acc_dst_ = ((pd->dst_md ()->data_type == data_type::s8)
122
122
|| (with_bias_
123
123
&& pd->weights_md (1 )->data_type
124
- != pd->dst_md ()->data_type ));
124
+ != pd->dst_md ()->data_type )
125
+ || use_f32_sum);
125
126
with_sum_ = with_sum;
127
+ with_f32_sum_ = use_f32_sum;
126
128
// scaling factor to add the previous destination value to the current
127
129
// computation. This is equivalent of
128
130
sum_scale_ = sum_scale (pd);
@@ -154,12 +156,23 @@ struct cudnn_gemm_inner_product_fwd_impl_t
154
156
155
157
if (with_bias_) {
156
158
CHECK (convert_data_type (pd->weights_md (1 ), &data_types_[io::bia]));
159
+
157
160
// format is always nchw
158
161
set_bias_dims (CUDNN_TENSOR_NCHW, ndims_, pd->OC ());
159
162
160
163
CHECK (create_and_set_tensor_descriptor (&tensor_descs_[io::bia],
161
164
data_types_[io::bia], ndims_, dims_[io::bia],
162
165
strides_[io::bia]));
166
+
167
+ if (with_f32_sum_) {
168
+ pd->scratchpad_registry ().registrar ().book (
169
+ memory_tracking::names::key_iprod_bias_bf16_convert_wsp,
170
+ memory_desc_wrapper (pd->weights_md (1 )).nelems (),
171
+ types::data_type_size (data_type::f32));
172
+ CHECK (create_and_set_tensor_descriptor (&bias_f32_desc_,
173
+ CUDNN_DATA_FLOAT, ndims_, dims_[io::bia],
174
+ strides_[io::bia]));
175
+ }
163
176
}
164
177
if (use_acc_dst_) {
165
178
pd->scratchpad_registry ().registrar ().book (
@@ -178,10 +191,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
178
191
179
192
void execute (cudnnHandle_t cudnn_handle, cublasHandle_t cublas_handle,
180
193
const std::vector<void *> &args) const override {
181
- assert (args.size () == 9 );
194
+ assert (args.size () == 10 );
182
195
auto x = args[0 ], w = args[1 ], b = args[2 ], y = args[3 ],
183
196
workspace = args[4 ], src_scale = args[6 ], wei_scale = args[7 ],
184
- dst_scale = args[8 ];
197
+ dst_scale = args[8 ], bias_f32 = args[ 9 ] ;
185
198
auto w_arg = w;
186
199
if (need_reorder_) {
187
200
void *transformed_w = args[5 ];
@@ -222,8 +235,18 @@ struct cudnn_gemm_inner_product_fwd_impl_t
222
235
223
236
if (with_bias_) {
224
237
float alpha = 1 .0f ;
225
- CUDNN_EXECUTE_FUNC (cudnnAddTensor, cudnn_handle, &alpha,
226
- tensor_descs_[io::bia], b, &alpha, y_acc_desc_, y_dst);
238
+ float beta = 0 .f ;
239
+ auto bias = b;
240
+ auto bias_desc = tensor_descs_[io::bia];
241
+ if (with_f32_sum_) {
242
+ cudnnTransformTensor (cudnn_handle, &alpha,
243
+ tensor_descs_[io::bia], b, &beta, bias_f32_desc_,
244
+ bias_f32);
245
+ bias = bias_f32;
246
+ bias_desc = bias_f32_desc_;
247
+ }
248
+ CUDNN_EXECUTE_FUNC (cudnnAddTensor, cudnn_handle, &alpha, bias_desc,
249
+ bias, &alpha, y_acc_desc_, y_dst);
227
250
}
228
251
if (with_eltwise_) {
229
252
CUDNN_EXECUTE_FUNC (cudnnActivationForward, cudnn_handle, act_desc_,
@@ -281,7 +304,7 @@ struct cudnn_gemm_inner_product_bwd_data_impl_t
281
304
282
305
virtual status_t init (impl::engine_t *, inner_product_pd_t *pd,
283
306
bool /* with_relu*/ , bool /* with_eltwise*/ , bool /* with_sum */ ,
284
- bool need_reorder) override {
307
+ bool need_reorder, bool /* use_f32_sum */ ) override {
285
308
need_reorder_ = need_reorder;
286
309
287
310
// GEMM is column major, here the data is row major.
@@ -365,7 +388,7 @@ struct cudnn_gemm_inner_product_bwd_weights_impl_t
365
388
}
366
389
virtual status_t init (impl::engine_t *engine, inner_product_pd_t *pd,
367
390
bool /* with_relu*/ , bool /* with_eltwise*/ , bool /* with_sum */ ,
368
- bool need_reorder) override {
391
+ bool need_reorder, bool /* use_f32_sum */ ) override {
369
392
need_reorder_ = need_reorder;
370
393
with_bias_ = pd->with_bias ();
371
394
0 commit comments