@@ -253,7 +253,6 @@ def _fbgemm_grouped_gemm(
253
253
tidx = tl .program_id (0 )
254
254
255
255
dtype : tl .dtype = c_ptr .dtype .element_ty
256
- c_desc_ptr = None
257
256
258
257
M_end_offset = 0
259
258
M_end_offset = M_end_offset .to (tl .int64 ) # pyre-ignore
@@ -273,10 +272,10 @@ def _fbgemm_grouped_gemm(
273
272
num_tiles = num_m_tiles * num_n_tiles
274
273
275
274
if USE_TMA_STORE :
276
- # pyre-ignore
277
275
c_desc_ptr = tl .make_tensor_descriptor (
278
276
c_ptr + M_start_offset * N ,
279
277
shape = [m_size , n_size ],
278
+ # pyre-ignore
280
279
strides = [n_size , 1 ],
281
280
block_shape = [BLOCK_SIZE_M , BLOCK_SIZE_N ],
282
281
)
@@ -422,7 +421,6 @@ def _fbgemm_grouped_gemm_ws(
422
421
tidx = tl .program_id (0 )
423
422
424
423
dtype : tl .dtype = c_ptr .dtype .element_ty
425
- c_desc_ptr = None
426
424
427
425
M_end_offset = 0
428
426
M_end_offset = M_end_offset .to (tl .int64 ) # pyre-ignore
@@ -443,10 +441,10 @@ def _fbgemm_grouped_gemm_ws(
443
441
444
442
if USE_TMA_STORE :
445
443
with tl .async_task ([0 ]):
446
- # pyre-ignore
447
444
c_desc_ptr = tl .make_tensor_descriptor (
448
445
c_ptr + M_start_offset * N ,
449
446
shape = [m_size , N ],
447
+ # pyre-ignore
450
448
strides = [N , 1 ],
451
449
block_shape = [BLOCK_SIZE_M , BLOCK_SIZE_N ],
452
450
)
@@ -586,7 +584,6 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
586
584
tidx = tl .program_id (0 )
587
585
588
586
dtype = TT_FP8_DTYPE
589
- c_desc_ptr = None
590
587
591
588
M_end_offset = 0
592
589
M_end_offset = M_end_offset .to (tl .int64 ) # pyre-ignore
@@ -606,10 +603,10 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
606
603
num_tiles = num_m_tiles * num_n_tiles
607
604
608
605
if USE_TMA_STORE :
609
- # pyre-ignore
610
606
c_desc_ptr = tl .make_tensor_descriptor (
611
607
c_ptr + M_start_offset * N ,
612
608
shape = [m_size , n_size ],
609
+ # pyre-ignore
613
610
strides = [n_size , 1 ],
614
611
block_shape = [BLOCK_SIZE_M , BLOCK_SIZE_N ],
615
612
)
@@ -756,7 +753,6 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
756
753
tidx = tl .program_id (0 )
757
754
758
755
dtype = TT_FP8_DTYPE
759
- c_desc_ptr = None
760
756
761
757
M_end_offset = 0
762
758
M_end_offset = M_end_offset .to (tl .int64 ) # pyre-ignore
@@ -777,10 +773,10 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
777
773
778
774
if USE_TMA_STORE :
779
775
with tl .async_task ([0 ]):
780
- # pyre-ignore
781
776
c_desc_ptr = tl .make_tensor_descriptor (
782
777
c_ptr + M_start_offset * N ,
783
778
shape = [m_size , N ],
779
+ # pyre-ignore
784
780
strides = [N , 1 ],
785
781
block_shape = [BLOCK_SIZE_M , BLOCK_SIZE_N ],
786
782
)
@@ -1003,7 +999,6 @@ def _grouped_gemm(
1003
999
desc_x = x
1004
1000
desc_w = w
1005
1001
desc_ws = w_scale
1006
- workspace = None
1007
1002
1008
1003
if USE_TMA_LOAD :
1009
1004
desc_helper = utils .TmaAutoTuneHelper ()
0 commit comments