Skip to content

Commit 3a5e741

Browse files
Peng Chen (Dev Infra)facebook-github-bot
authored andcommitted
Fix lint errors from prev diff (#4870)
Summary: X-link: facebookresearch/FBGEMM#1892 Pull Request resolved: #4870 tsia Reviewed By: htyu Differential Revision: D82337107 fbshipit-source-id: 8129eeece000e5aaae14d4a90a460451aa0a0ff1
1 parent e699b6f commit 3a5e741

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ def _fbgemm_grouped_gemm(
253253
tidx = tl.program_id(0)
254254

255255
dtype: tl.dtype = c_ptr.dtype.element_ty
256-
c_desc_ptr = None
257256

258257
M_end_offset = 0
259258
M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
@@ -273,10 +272,10 @@ def _fbgemm_grouped_gemm(
273272
num_tiles = num_m_tiles * num_n_tiles
274273

275274
if USE_TMA_STORE:
276-
# pyre-ignore
277275
c_desc_ptr = tl.make_tensor_descriptor(
278276
c_ptr + M_start_offset * N,
279277
shape=[m_size, n_size],
278+
# pyre-ignore
280279
strides=[n_size, 1],
281280
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
282281
)
@@ -422,7 +421,6 @@ def _fbgemm_grouped_gemm_ws(
422421
tidx = tl.program_id(0)
423422

424423
dtype: tl.dtype = c_ptr.dtype.element_ty
425-
c_desc_ptr = None
426424

427425
M_end_offset = 0
428426
M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
@@ -443,10 +441,10 @@ def _fbgemm_grouped_gemm_ws(
443441

444442
if USE_TMA_STORE:
445443
with tl.async_task([0]):
446-
# pyre-ignore
447444
c_desc_ptr = tl.make_tensor_descriptor(
448445
c_ptr + M_start_offset * N,
449446
shape=[m_size, N],
447+
# pyre-ignore
450448
strides=[N, 1],
451449
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
452450
)
@@ -586,7 +584,6 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
586584
tidx = tl.program_id(0)
587585

588586
dtype = TT_FP8_DTYPE
589-
c_desc_ptr = None
590587

591588
M_end_offset = 0
592589
M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
@@ -606,10 +603,10 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
606603
num_tiles = num_m_tiles * num_n_tiles
607604

608605
if USE_TMA_STORE:
609-
# pyre-ignore
610606
c_desc_ptr = tl.make_tensor_descriptor(
611607
c_ptr + M_start_offset * N,
612608
shape=[m_size, n_size],
609+
# pyre-ignore
613610
strides=[n_size, 1],
614611
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
615612
)
@@ -756,7 +753,6 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
756753
tidx = tl.program_id(0)
757754

758755
dtype = TT_FP8_DTYPE
759-
c_desc_ptr = None
760756

761757
M_end_offset = 0
762758
M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
@@ -777,10 +773,10 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
777773

778774
if USE_TMA_STORE:
779775
with tl.async_task([0]):
780-
# pyre-ignore
781776
c_desc_ptr = tl.make_tensor_descriptor(
782777
c_ptr + M_start_offset * N,
783778
shape=[m_size, N],
779+
# pyre-ignore
784780
strides=[N, 1],
785781
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
786782
)
@@ -1003,7 +999,6 @@ def _grouped_gemm(
1003999
desc_x = x
10041000
desc_w = w
10051001
desc_ws = w_scale
1006-
workspace = None
10071002

10081003
if USE_TMA_LOAD:
10091004
desc_helper = utils.TmaAutoTuneHelper()

0 commit comments

Comments
 (0)