Commit fd2915e

gpu: nvidia: matmul: fix issues with scaling (#2564)
1 parent 447ea75 commit fd2915e

6 files changed: +49, -269 lines

src/common/memory_tracking.hpp (-2)

@@ -266,8 +266,6 @@ enum {
     key_matmul_wei_trans,
     key_matmul_dst_trans,
     key_matmul_dst_cast_acc,
-    key_matmul_lt_src_scale,
-    key_matmul_lt_wei_scale,
     key_matmul_sparse_tmp_ptr,
     key_pool_dst_bf16cvt,
     key_pool_dst_plain2blocked_cvt,
src/gpu/generic/sycl/ref_matmul.hpp (+4, -1)

@@ -122,7 +122,10 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
         const auto &scales = attr()->scales_;
         bool dt_ok = true;
         for (auto arg : supported_args) {
-            dt_ok = dt_ok && is_supported_type(scales.get_data_type(arg));
+            if (!scales.get(arg).has_default_values()) {
+                dt_ok = dt_ok
+                        && is_supported_type(scales.get_data_type(arg));
+            }
         }
         return dt_ok && attr_scales_ok(supported_args);
     }
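
The fix here: previously the scale data type was checked for every supported argument, so an argument whose scales keep their default (unset) values reported an undefined data type and could fail the check. A self-contained sketch of the guard pattern, using hypothetical stand-ins for the oneDNN attribute types:

    #include <array>
    #include <cstdio>

    enum class data_type { undef, f32, f16 };

    struct scale_entry_t {
        bool set = false;                // scales were explicitly provided
        data_type dt = data_type::undef; // scale data type, if set
        bool has_default_values() const { return !set; }
    };

    static bool is_supported_type(data_type dt) {
        return dt == data_type::f32 || dt == data_type::f16;
    }

    int main() {
        // Hypothetical mapping: index 0 = SRC, 1 = WEIGHTS, 2 = DST.
        std::array<scale_entry_t, 3> scales {};
        scales[1].set = true;            // only WEIGHTS scales are set
        scales[1].dt = data_type::f32;

        bool dt_ok = true;
        for (const auto &s : scales) {
            // Skip arguments whose scales keep default values; their data
            // type is undefined and must not fail the check.
            if (!s.has_default_values())
                dt_ok = dt_ok && is_supported_type(s.dt);
        }
        std::printf("dt_ok = %d\n", dt_ok); // prints 1
        return 0;
    }
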

src/gpu/nvidia/cudnn_matmul.cpp (-56)

@@ -66,66 +66,10 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     nvidia::stream_t *cuda_stream
             = utils::downcast<nvidia::stream_t *>(ctx.stream());

-    const bool has_src_scales
-            = ctx.args().find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)
-            != ctx.args().end();
-    const bool has_wei_scales
-            = ctx.args().find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS)
-            != ctx.args().end();
     const bool has_dst_scales
             = ctx.args().find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)
             != ctx.args().end();

-    if (has_src_scales
-            && (pd()->params_->multi_src_scale_
-                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
-        // src scale sycl binary
-        exec_args_t src_scale_binary_args;
-        src_scale_binary_args[DNNL_ARG_SRC_0]
-                = memory_arg_t {ctx.args().at(DNNL_ARG_SRC).mem, true};
-        src_scale_binary_args[DNNL_ARG_SRC_1] = memory_arg_t {
-                ctx.args().at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC).mem, true};
-
-        std::unique_ptr<memory_t, memory_deleter_t> scratch_mem;
-        auto scratchpad_storage
-                = ctx.get_scratchpad_grantor().get_memory_storage(
-                        memory_tracking::names::key_matmul_lt_src_scale);
-        safe_ptr_assign(scratch_mem,
-                new memory_t(ctx.stream()->engine(), pd()->src_md(),
-                        std::move(scratchpad_storage)));
-        src_scale_binary_args[DNNL_ARG_DST]
-                = memory_arg_t {scratch_mem.get(), false};
-
-        exec_ctx_t binary_ctx(ctx, std::move(src_scale_binary_args));
-
-        CHECK(src_scale_binary_->execute(binary_ctx));
-    }
-    if (has_wei_scales
-            && (pd()->params_->multi_wei_scale_
-                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
-        // wei scale sycl binary
-        exec_args_t wei_scale_binary_args;
-        wei_scale_binary_args[DNNL_ARG_SRC_0]
-                = memory_arg_t {ctx.args().at(DNNL_ARG_WEIGHTS).mem, true};
-        wei_scale_binary_args[DNNL_ARG_SRC_1] = memory_arg_t {
-                ctx.args().at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS).mem,
-                true};
-
-        std::unique_ptr<memory_t, memory_deleter_t> scratch_mem;
-        auto scratchpad_storage
-                = ctx.get_scratchpad_grantor().get_memory_storage(
-                        memory_tracking::names::key_matmul_lt_wei_scale);
-        safe_ptr_assign(scratch_mem,
-                new memory_t(ctx.stream()->engine(), pd()->weights_md(0),
-                        std::move(scratchpad_storage)));
-        wei_scale_binary_args[DNNL_ARG_DST]
-                = memory_arg_t {scratch_mem.get(), false};
-
-        exec_ctx_t binary_ctx(ctx, std::move(wei_scale_binary_args));
-
-        CHECK(wei_scale_binary_->execute(binary_ctx));
-    }
-
     CHECK(executor_->execute(ctx, ctx.stream()->engine(), matmul_impl_,
             pd()->params_, src_d, weights_d, dst_d));
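
After this change, execute() no longer pre-scales SRC and WEIGHTS through separate SYCL binary primitives into scratchpad copies of the tensors before handing off to the cuBLASLt executor; only the DST scale argument is still looked up here. How the library applies the SRC/WEIGHTS scales after this change is not visible in this hunk. For orientation only, cuBLASLt can consume per-tensor scales directly through the matmul descriptor's scale-pointer attributes, roughly as sketched below; whether matmul_impl_ uses this mechanism is an assumption, and device_scale_a/b/d are hypothetical device pointers:

    // Orientation-only sketch, not code from this commit.
    #include <cublasLt.h>

    void set_tensor_scales(cublasLtMatmulDesc_t desc,
            const float *device_scale_a, const float *device_scale_b,
            const float *device_scale_d) {
        // The pointers are dereferenced on the device when the matmul runs.
        cublasLtMatmulDescSetAttribute(desc,
                CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &device_scale_a,
                sizeof(device_scale_a));
        cublasLtMatmulDescSetAttribute(desc,
                CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &device_scale_b,
                sizeof(device_scale_b));
        cublasLtMatmulDescSetAttribute(desc,
                CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &device_scale_d,
                sizeof(device_scale_d));
    }
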

src/gpu/nvidia/cudnn_matmul_executor.hpp (+9, -62)

@@ -245,18 +245,11 @@ struct cudnn_matmul_lt_base_exec_t {
             xpu::sycl::interop_memory_arg_t<scratch_m> arg_block_a_scratch,
             xpu::sycl::interop_memory_arg_t<scratch_m> arg_block_b_scratch,
             xpu::sycl::interop_memory_arg_t<scratch_m> arg_block_c_scratch,
-            xpu::sycl::interop_memory_arg_t<scratch_m> scaled_arg_src,
-            xpu::sycl::interop_memory_arg_t<scratch_m> scaled_arg_wt,
-            xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read>
-                    arg_src_scale,
-            xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read>
-                    arg_wei_scale,
             xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read>
                     arg_dst_scale,
             uint8_t *algo_scratch_ptr, uint8_t *bias_scratch_ptr,
             uint8_t *block_a_scratch_ptr, uint8_t *block_b_scratch_ptr,
-            uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
-            uint8_t *wei_scale_scratch_ptr) {
+            uint8_t *block_c_scratch_ptr) {

         compat::host_task(cgh,
                 [= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {

@@ -282,29 +275,22 @@ struct cudnn_matmul_lt_base_exec_t {
                     void *block_c_scratch
                             = arg_block_c_scratch.get_native_pointer(ih);

-                    void *scaled_src = scaled_arg_src.get_native_pointer(ih);
-                    void *scaled_wt = scaled_arg_wt.get_native_pointer(ih);
-
                     void *bias = arg_bias.get_native_pointer(ih);
                     void *weights = arg_weights.get_native_pointer(ih);
                     void *src = arg_src.get_native_pointer(ih);
                     void *dst = arg_dst.get_native_pointer(ih);

-                    void *src_scale = arg_src_scale.get_native_pointer(ih);
-                    void *wei_scale = arg_wei_scale.get_native_pointer(ih);
                     void *dst_scale = arg_dst_scale.get_native_pointer(ih);

                     matmul_impl_->execute(cublas_handle, params, weights, src,
                             dst, bias, algo_scratch, reorder_scratch,
                             block_a_scratch, block_b_scratch, block_c_scratch,
-                            scaled_src, scaled_wt, src_scale, wei_scale,
-                            dst_scale);
+                            nullptr, nullptr, dst_scale);

                     free_runtime_scratch(params->has_runtime_params_,
                             cublas_handle, cuda_stream, algo_scratch_ptr,
                             bias_scratch_ptr, block_a_scratch_ptr,
-                            block_b_scratch_ptr, block_c_scratch_ptr,
-                            src_scale_scratch_ptr, wei_scale_scratch_ptr);
+                            block_b_scratch_ptr, block_c_scratch_ptr);
                     if (params->has_runtime_params_) { params->rt_cleanup(); }
                 });
     }

@@ -314,8 +300,7 @@ struct cudnn_matmul_lt_base_exec_t {
             cublasHandle_t cublas_handle, nvidia::stream_t *cuda_stream,
             uint8_t *algo_scratch_ptr, uint8_t *bias_scratch_ptr,
             uint8_t *block_a_scratch_ptr, uint8_t *block_b_scratch_ptr,
-            uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
-            uint8_t *wei_scale_scratch_ptr) {
+            uint8_t *block_c_scratch_ptr) {
         if (has_runtime_params || bias_scratch_ptr) {
             cudaStream_t streamId;
             cublasGetStream(cublas_handle, &streamId);

@@ -335,12 +320,6 @@ struct cudnn_matmul_lt_base_exec_t {
             if (block_c_scratch_ptr) {
                 ::sycl::free(block_c_scratch_ptr, cuda_stream->queue());
             }
-            if (src_scale_scratch_ptr) {
-                ::sycl::free(src_scale_scratch_ptr, cuda_stream->queue());
-            }
-            if (wei_scale_scratch_ptr) {
-                ::sycl::free(wei_scale_scratch_ptr, cuda_stream->queue());
-            }
         }
     }

@@ -375,11 +354,6 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
         auto arg_bias = CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS);
         auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST);

-        auto arg_src_scale
-                = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
-
-        auto arg_wei_scale = CTX_IN_SYCL_MEMORY(
-                DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
         auto arg_dst_scale
                 = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
         auto arg_algo_scratch = params->algo_scratch_size_ != 0

@@ -407,23 +381,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
                             memory_tracking::names::key_matmul_lt_block_c)
                     : xpu::sycl::interop_memory_arg_t<
                             ::sycl::access::mode::read_write>();
-            auto scaled_arg_src = params->src_scale_size_ != 0
-                    ? CTX_SCRATCH_SYCL_MEMORY(
-                            memory_tracking::names::key_matmul_lt_src_scale)
-                    : xpu::sycl::interop_memory_arg_t<
-                            ::sycl::access::mode::read_write>();
-            auto scaled_arg_wt = params->wei_scale_size_ != 0
-                    ? CTX_SCRATCH_SYCL_MEMORY(
-                            memory_tracking::names::key_matmul_lt_wei_scale)
-                    : xpu::sycl::interop_memory_arg_t<
-                            ::sycl::access::mode::read_write>();

             interop_task(matmul_impl_, params, engine, cgh, cuda_stream, arg_wt,
                     arg_src, arg_dst, arg_bias, arg_algo_scratch,
                     arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch,
-                    arg_block_c_scratch, scaled_arg_src, scaled_arg_wt,
-                    arg_src_scale, arg_wei_scale, arg_dst_scale, nullptr,
-                    nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
+                    arg_block_c_scratch, arg_dst_scale, nullptr, nullptr,
+                    nullptr, nullptr, nullptr);
         });
     }

@@ -465,12 +428,6 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
         uint8_t *block_c_scratch_ptr
                 = alloc_ptr(matmul_params->dest_size_, cuda_stream->queue());

-        uint8_t *src_scale_scratch_ptr = alloc_ptr(
-                matmul_params->src_scale_size_, cuda_stream->queue());
-
-        uint8_t *wei_scale_scratch_ptr = alloc_ptr(
-                matmul_params->wei_scale_size_, cuda_stream->queue());
-
         return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE](
                                                  ::sycl::handler &cgh) {
             auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);

@@ -488,26 +445,16 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
                     matmul_params->weight_size_, block_b_scratch_ptr);
             auto arg_block_c_scratch = init_scratch_from_ptr(
                     matmul_params->dest_size_, block_c_scratch_ptr);
-            auto scaled_arg_src = init_scratch_from_ptr(
-                    matmul_params->src_scale_size_, src_scale_scratch_ptr);
-            auto scaled_arg_wt = init_scratch_from_ptr(
-                    matmul_params->wei_scale_size_, wei_scale_scratch_ptr);

-            auto arg_src_scale
-                    = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
-            auto arg_wei_scale = CTX_IN_SYCL_MEMORY(
-                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
             auto arg_dst_scale
                     = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);

             interop_task(matmul_impl_, matmul_params, engine, cgh, cuda_stream,
                     arg_wt, arg_src, arg_dst, arg_bias, arg_algo_scratch,
                     arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch,
-                    arg_block_c_scratch, scaled_arg_src, scaled_arg_wt,
-                    arg_src_scale, arg_wei_scale, arg_dst_scale,
-                    algo_scratch_ptr, bias_scratch_ptr, block_a_scratch_ptr,
-                    block_b_scratch_ptr, block_c_scratch_ptr,
-                    src_scale_scratch_ptr, wei_scale_scratch_ptr);
+                    arg_block_c_scratch, arg_dst_scale, algo_scratch_ptr,
+                    bias_scratch_ptr, block_a_scratch_ptr, block_b_scratch_ptr,
+                    block_c_scratch_ptr);
         });
     }
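
Taken together, the executors now forward only the DST scale to matmul_impl_->execute(), pass nullptr for the former scaled-tensor arguments, and the runtime-args path no longer allocates or frees the two scale scratch buffers. A hedged usage sketch of the affected feature through the public oneDNN C++ API (engine kind, shapes, and data types are arbitrary illustrative choices, and data initialization is omitted):

    #include "dnnl.hpp"

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::gpu, 0);
        stream strm(eng);

        memory::desc src_md({2, 16}, memory::data_type::f16, memory::format_tag::ab);
        memory::desc wei_md({16, 8}, memory::data_type::f16, memory::format_tag::ab);
        memory::desc dst_md({2, 8}, memory::data_type::f16, memory::format_tag::ab);

        // mask = 0 -> a single (per-tensor) scale for each argument.
        primitive_attr attr;
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
        attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
        attr.set_scales_mask(DNNL_ARG_DST, 0);

        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
        matmul mm(pd);

        memory src(src_md, eng), wei(wei_md, eng), dst(dst_md, eng);
        memory::desc scale_md({1}, memory::data_type::f32, memory::format_tag::x);
        memory src_scale(scale_md, eng), wei_scale(scale_md, eng),
                dst_scale(scale_md, eng);

        // Scale memories are passed as runtime arguments alongside the tensors.
        mm.execute(strm,
                {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
                        {DNNL_ARG_DST, dst},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scale},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale}});
        strm.wait();
        return 0;
    }
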
