@@ -75,10 +75,12 @@ struct cudnn_gemm_inner_product_fwd_impl_t
75
75
bool use_acc_dst_;
76
76
cudnnTensorDescriptor_t y_acc_desc_;
77
77
bool need_reorder_;
78
+ cudnnTensorDescriptor_t bias_f32_desc_;
79
+ bool with_f32_sum_ = false ;
78
80
79
81
virtual status_t init (impl::engine_t *, inner_product_pd_t *pd,
80
- bool with_relu, bool with_eltwise, bool with_sum,
81
- bool need_reorder ) override {
82
+ bool with_relu, bool with_eltwise, bool with_sum, bool need_reorder,
83
+ bool use_f32_sum ) override {
82
84
need_reorder_ = need_reorder;
83
85
// GEMM is column major, here the data is row major.
84
86
// By switching the weight and source we convert the row major to
@@ -121,8 +123,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
121
123
use_acc_dst_ = ((pd->dst_md ()->data_type == data_type::s8)
122
124
|| (with_bias_
123
125
&& pd->weights_md (1 )->data_type
124
- != pd->dst_md ()->data_type ));
126
+ != pd->dst_md ()->data_type )
127
+ || use_f32_sum);
125
128
with_sum_ = with_sum;
129
+ with_f32_sum_ = use_f32_sum;
126
130
// scaling factor to add the previous destination value to the current
127
131
// computation. This is equivalent of
128
132
sum_scale_ = sum_scale (pd);
@@ -154,12 +158,23 @@ struct cudnn_gemm_inner_product_fwd_impl_t
154
158
155
159
if (with_bias_) {
156
160
CHECK (convert_data_type (pd->weights_md (1 ), &data_types_[io::bia]));
161
+
157
162
// format is always nchw
158
163
set_bias_dims (CUDNN_TENSOR_NCHW, ndims_, pd->OC ());
159
164
160
165
CHECK (create_and_set_tensor_descriptor (&tensor_descs_[io::bia],
161
166
data_types_[io::bia], ndims_, dims_[io::bia],
162
167
strides_[io::bia]));
168
+
169
+ if (with_f32_sum_) {
170
+ pd->scratchpad_registry ().registrar ().book (
171
+ memory_tracking::names::key_iprod_bias_bf16_convert_wsp,
172
+ memory_desc_wrapper (pd->weights_md (1 )).nelems (),
173
+ types::data_type_size (data_type::f32));
174
+ CHECK (create_and_set_tensor_descriptor (&bias_f32_desc_,
175
+ CUDNN_DATA_FLOAT, ndims_, dims_[io::bia],
176
+ strides_[io::bia]));
177
+ }
163
178
}
164
179
if (use_acc_dst_) {
165
180
pd->scratchpad_registry ().registrar ().book (
@@ -178,10 +193,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
178
193
179
194
void execute (cudnnHandle_t cudnn_handle, cublasHandle_t cublas_handle,
180
195
const std::vector<void *> &args) const override {
181
- assert (args.size () == 9 );
196
+ assert (args.size () == 10 );
182
197
auto x = args[0 ], w = args[1 ], b = args[2 ], y = args[3 ],
183
198
workspace = args[4 ], src_scale = args[6 ], wei_scale = args[7 ],
184
- dst_scale = args[8 ];
199
+ dst_scale = args[8 ], bias_f32 = args[ 9 ] ;
185
200
auto w_arg = w;
186
201
if (need_reorder_) {
187
202
void *transformed_w = args[5 ];
@@ -222,8 +237,18 @@ struct cudnn_gemm_inner_product_fwd_impl_t
222
237
223
238
if (with_bias_) {
224
239
float alpha = 1 .0f ;
225
- CUDNN_EXECUTE_FUNC (cudnnAddTensor, cudnn_handle, &alpha,
226
- tensor_descs_[io::bia], b, &alpha, y_acc_desc_, y_dst);
240
+ float beta = 0 .f ;
241
+ auto bias = b;
242
+ auto bias_desc = tensor_descs_[io::bia];
243
+ if (with_f32_sum_) {
244
+ cudnnTransformTensor (cudnn_handle, &alpha,
245
+ tensor_descs_[io::bia], b, &beta, bias_f32_desc_,
246
+ bias_f32);
247
+ bias = bias_f32;
248
+ bias_desc = bias_f32_desc_;
249
+ }
250
+ CUDNN_EXECUTE_FUNC (cudnnAddTensor, cudnn_handle, &alpha, bias_desc,
251
+ bias, &alpha, y_acc_desc_, y_dst);
227
252
}
228
253
if (with_eltwise_) {
229
254
CUDNN_EXECUTE_FUNC (cudnnActivationForward, cudnn_handle, act_desc_,
@@ -271,6 +296,10 @@ struct cudnn_gemm_inner_product_fwd_impl_t
271
296
272
297
return status::success;
273
298
}
299
+
300
+ ~cudnn_gemm_inner_product_fwd_impl_t () {
301
+ if (with_f32_sum_) { cudnnDestroyTensorDescriptor (bias_f32_desc_); }
302
+ }
274
303
};
275
304
276
305
struct cudnn_gemm_inner_product_bwd_data_impl_t
@@ -281,7 +310,7 @@ struct cudnn_gemm_inner_product_bwd_data_impl_t
281
310
282
311
virtual status_t init (impl::engine_t *, inner_product_pd_t *pd,
283
312
bool /* with_relu*/ , bool /* with_eltwise*/ , bool /* with_sum */ ,
284
- bool need_reorder) override {
313
+ bool need_reorder, bool /* use_f32_sum */ ) override {
285
314
need_reorder_ = need_reorder;
286
315
287
316
// GEMM is column major, here the data is row major.
@@ -365,7 +394,7 @@ struct cudnn_gemm_inner_product_bwd_weights_impl_t
365
394
}
366
395
virtual status_t init (impl::engine_t *engine, inner_product_pd_t *pd,
367
396
bool /* with_relu*/ , bool /* with_eltwise*/ , bool /* with_sum */ ,
368
- bool need_reorder) override {
397
+ bool need_reorder, bool /* use_f32_sum */ ) override {
369
398
need_reorder_ = need_reorder;
370
399
with_bias_ = pd->with_bias ();
371
400
0 commit comments