Skip to content

Commit d442b85

Browse files
authored
Update TLX groupedGEMM kernel
Differential Revision: D88184206 Pull Request resolved: #678
1 parent c9175c6 commit d442b85

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

tritonbench/operators/grouped_gemm/kernels.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ def grouped_matmul_tlx_kernel(
383383
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
384384
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
385385
num_tiles = num_m_tiles * num_n_tiles
386+
num_k_tiles = tl.cdiv(gk, BLOCK_SIZE_K)
387+
386388
if (
387389
tile_idx >= last_problem_end
388390
and tile_idx < last_problem_end + num_tiles
@@ -413,7 +415,6 @@ def grouped_matmul_tlx_kernel(
413415
tile_idx >= last_problem_end
414416
and tile_idx < last_problem_end + num_tiles
415417
):
416-
k = gk
417418
# figure out tile coordinates
418419
tile_idx_in_gemm = tile_idx - last_problem_end
419420
tile_m_idx = tile_idx_in_gemm // num_n_tiles
@@ -423,7 +424,7 @@ def grouped_matmul_tlx_kernel(
423424
offs_am = tile_m_idx * BLOCK_SIZE_M
424425
offs_bn = tile_n_idx * BLOCK_SIZE_N
425426

426-
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
427+
for kk in range(0, num_k_tiles):
427428
buf, phase = _get_bufidx_phase(accum_cnt, NUM_SMEM_BUFFERS)
428429
tlx.barrier_wait(smem_empty_bars[buf], phase ^ 1)
429430
tlx.barrier_expect_bytes(
@@ -447,6 +448,12 @@ def grouped_matmul_tlx_kernel(
447448
# go to the next tile by advancing NUM_SMS
448449
tile_idx += NUM_SMS
449450

451+
# Wait for the last pair of TMA load to complete before doing
452+
# the TMA desc update for the next gemm problem.
453+
if num_k_tiles > 0:
454+
buf, phase = _get_bufidx_phase(accum_cnt - 1, NUM_SMEM_BUFFERS)
455+
tlx.barrier_wait(smem_full_bars[buf], phase)
456+
450457
# get ready to go to the next gemm problem
451458
last_problem_end = last_problem_end + num_tiles
452459

tritonbench/operators/grouped_gemm/operator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@
5252
from .kernels import tlx_group_gemm_fn
5353

5454

55-
IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
55+
IS_B200 = is_cuda() and get_nvidia_gpu_model() in (
56+
"NVIDIA B200",
57+
"NVIDIA GB200",
58+
"NVIDIA GB300",
59+
)
5660

5761

5862
def get_default_shapes():

0 commit comments

Comments
 (0)