@@ -245,18 +245,11 @@ struct cudnn_matmul_lt_base_exec_t {
245
245
xpu::sycl::interop_memory_arg_t <scratch_m> arg_block_a_scratch,
246
246
xpu::sycl::interop_memory_arg_t <scratch_m> arg_block_b_scratch,
247
247
xpu::sycl::interop_memory_arg_t <scratch_m> arg_block_c_scratch,
248
- xpu::sycl::interop_memory_arg_t <scratch_m> scaled_arg_src,
249
- xpu::sycl::interop_memory_arg_t <scratch_m> scaled_arg_wt,
250
- xpu::sycl::interop_memory_arg_t <::sycl::access::mode::read>
251
- arg_src_scale,
252
- xpu::sycl::interop_memory_arg_t <::sycl::access::mode::read>
253
- arg_wei_scale,
254
248
xpu::sycl::interop_memory_arg_t <::sycl::access::mode::read>
255
249
arg_dst_scale,
256
250
uint8_t *algo_scratch_ptr, uint8_t *bias_scratch_ptr,
257
251
uint8_t *block_a_scratch_ptr, uint8_t *block_b_scratch_ptr,
258
- uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
259
- uint8_t *wei_scale_scratch_ptr) {
252
+ uint8_t *block_c_scratch_ptr) {
260
253
261
254
compat::host_task (cgh,
262
255
[= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {
@@ -282,29 +275,22 @@ struct cudnn_matmul_lt_base_exec_t {
282
275
void *block_c_scratch
283
276
= arg_block_c_scratch.get_native_pointer (ih);
284
277
285
- void *scaled_src = scaled_arg_src.get_native_pointer (ih);
286
- void *scaled_wt = scaled_arg_wt.get_native_pointer (ih);
287
-
288
278
void *bias = arg_bias.get_native_pointer (ih);
289
279
void *weights = arg_weights.get_native_pointer (ih);
290
280
void *src = arg_src.get_native_pointer (ih);
291
281
void *dst = arg_dst.get_native_pointer (ih);
292
282
293
- void *src_scale = arg_src_scale.get_native_pointer (ih);
294
- void *wei_scale = arg_wei_scale.get_native_pointer (ih);
295
283
void *dst_scale = arg_dst_scale.get_native_pointer (ih);
296
284
297
285
matmul_impl_->execute (cublas_handle, params, weights, src,
298
286
dst, bias, algo_scratch, reorder_scratch,
299
287
block_a_scratch, block_b_scratch, block_c_scratch,
300
- scaled_src, scaled_wt, src_scale, wei_scale,
301
- dst_scale);
288
+ nullptr , nullptr , dst_scale);
302
289
303
290
free_runtime_scratch (params->has_runtime_params_ ,
304
291
cublas_handle, cuda_stream, algo_scratch_ptr,
305
292
bias_scratch_ptr, block_a_scratch_ptr,
306
- block_b_scratch_ptr, block_c_scratch_ptr,
307
- src_scale_scratch_ptr, wei_scale_scratch_ptr);
293
+ block_b_scratch_ptr, block_c_scratch_ptr);
308
294
if (params->has_runtime_params_ ) { params->rt_cleanup (); }
309
295
});
310
296
}
@@ -314,8 +300,7 @@ struct cudnn_matmul_lt_base_exec_t {
314
300
cublasHandle_t cublas_handle, nvidia::stream_t *cuda_stream,
315
301
uint8_t *algo_scratch_ptr, uint8_t *bias_scratch_ptr,
316
302
uint8_t *block_a_scratch_ptr, uint8_t *block_b_scratch_ptr,
317
- uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
318
- uint8_t *wei_scale_scratch_ptr) {
303
+ uint8_t *block_c_scratch_ptr) {
319
304
if (has_runtime_params || bias_scratch_ptr) {
320
305
cudaStream_t streamId;
321
306
cublasGetStream (cublas_handle, &streamId);
@@ -335,12 +320,6 @@ struct cudnn_matmul_lt_base_exec_t {
335
320
if (block_c_scratch_ptr) {
336
321
::sycl::free (block_c_scratch_ptr, cuda_stream->queue ());
337
322
}
338
- if (src_scale_scratch_ptr) {
339
- ::sycl::free (src_scale_scratch_ptr, cuda_stream->queue ());
340
- }
341
- if (wei_scale_scratch_ptr) {
342
- ::sycl::free (wei_scale_scratch_ptr, cuda_stream->queue ());
343
- }
344
323
}
345
324
}
346
325
@@ -375,11 +354,6 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
375
354
auto arg_bias = CTX_IN_SYCL_MEMORY (DNNL_ARG_BIAS);
376
355
auto arg_dst = CTX_OUT_SYCL_MEMORY (DNNL_ARG_DST);
377
356
378
- auto arg_src_scale
379
- = CTX_IN_SYCL_MEMORY (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
380
-
381
- auto arg_wei_scale = CTX_IN_SYCL_MEMORY (
382
- DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
383
357
auto arg_dst_scale
384
358
= CTX_IN_SYCL_MEMORY (DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
385
359
auto arg_algo_scratch = params->algo_scratch_size_ != 0
@@ -407,23 +381,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
407
381
memory_tracking::names::key_matmul_lt_block_c)
408
382
: xpu::sycl::interop_memory_arg_t <
409
383
::sycl::access ::mode::read_write>();
410
- auto scaled_arg_src = params->src_scale_size_ != 0
411
- ? CTX_SCRATCH_SYCL_MEMORY (
412
- memory_tracking::names::key_matmul_lt_src_scale)
413
- : xpu::sycl::interop_memory_arg_t <
414
- ::sycl::access ::mode::read_write>();
415
- auto scaled_arg_wt = params->wei_scale_size_ != 0
416
- ? CTX_SCRATCH_SYCL_MEMORY (
417
- memory_tracking::names::key_matmul_lt_wei_scale)
418
- : xpu::sycl::interop_memory_arg_t <
419
- ::sycl::access ::mode::read_write>();
420
384
421
385
interop_task (matmul_impl_, params, engine, cgh, cuda_stream, arg_wt,
422
386
arg_src, arg_dst, arg_bias, arg_algo_scratch,
423
387
arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch,
424
- arg_block_c_scratch, scaled_arg_src, scaled_arg_wt,
425
- arg_src_scale, arg_wei_scale, arg_dst_scale, nullptr ,
426
- nullptr , nullptr , nullptr , nullptr , nullptr , nullptr );
388
+ arg_block_c_scratch, arg_dst_scale, nullptr , nullptr ,
389
+ nullptr , nullptr , nullptr );
427
390
});
428
391
}
429
392
@@ -465,12 +428,6 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
465
428
uint8_t *block_c_scratch_ptr
466
429
= alloc_ptr (matmul_params->dest_size_ , cuda_stream->queue ());
467
430
468
- uint8_t *src_scale_scratch_ptr = alloc_ptr (
469
- matmul_params->src_scale_size_ , cuda_stream->queue ());
470
-
471
- uint8_t *wei_scale_scratch_ptr = alloc_ptr (
472
- matmul_params->wei_scale_size_ , cuda_stream->queue ());
473
-
474
431
return cuda_stream->interop_task ([= WA_THIS_COPY_CAPTURE](
475
432
::sycl::handler &cgh) {
476
433
auto arg_src = CTX_IN_SYCL_MEMORY (DNNL_ARG_SRC);
@@ -488,26 +445,16 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
488
445
matmul_params->weight_size_ , block_b_scratch_ptr);
489
446
auto arg_block_c_scratch = init_scratch_from_ptr (
490
447
matmul_params->dest_size_ , block_c_scratch_ptr);
491
- auto scaled_arg_src = init_scratch_from_ptr (
492
- matmul_params->src_scale_size_ , src_scale_scratch_ptr);
493
- auto scaled_arg_wt = init_scratch_from_ptr (
494
- matmul_params->wei_scale_size_ , wei_scale_scratch_ptr);
495
448
496
- auto arg_src_scale
497
- = CTX_IN_SYCL_MEMORY (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
498
- auto arg_wei_scale = CTX_IN_SYCL_MEMORY (
499
- DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
500
449
auto arg_dst_scale
501
450
= CTX_IN_SYCL_MEMORY (DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
502
451
503
452
interop_task (matmul_impl_, matmul_params, engine, cgh, cuda_stream,
504
453
arg_wt, arg_src, arg_dst, arg_bias, arg_algo_scratch,
505
454
arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch,
506
- arg_block_c_scratch, scaled_arg_src, scaled_arg_wt,
507
- arg_src_scale, arg_wei_scale, arg_dst_scale,
508
- algo_scratch_ptr, bias_scratch_ptr, block_a_scratch_ptr,
509
- block_b_scratch_ptr, block_c_scratch_ptr,
510
- src_scale_scratch_ptr, wei_scale_scratch_ptr);
455
+ arg_block_c_scratch, arg_dst_scale, algo_scratch_ptr,
456
+ bias_scratch_ptr, block_a_scratch_ptr, block_b_scratch_ptr,
457
+ block_c_scratch_ptr);
511
458
});
512
459
}
513
460
0 commit comments